| | from __future__ import annotations |
| |
|
| | import spaces |
| | import math |
| | import random |
| | import sys |
| | from argparse import ArgumentParser |
| |
|
| | from tqdm.auto import trange |
| | import einops |
| | import gradio as gr |
| | import k_diffusion as K |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | from einops import rearrange |
| | from omegaconf import OmegaConf |
| | from PIL import Image, ImageOps, ImageFilter |
| | from torch import autocast |
| | import cv2 |
| | import imageio |
| |
|
| | sys.path.append("./stable_diffusion") |
| |
|
| | from stable_diffusion.ldm.util import instantiate_from_config |
| |
|
class CFGDenoiser(nn.Module):
    """Classifier-free-guidance wrapper around a two-output denoiser.

    Each call evaluates the wrapped model once on a batch of three copies of
    the inputs, ordered as (text + image conditioning, image conditioning
    only, no conditioning). The first output stream is combined with separate
    text and image guidance scales; the second stream is returned from the
    fully conditioned branch only.
    """

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, z_0, z_1, sigma, cond, uncond, text_cfg_scale, image_cfg_scale):
        # Triplicate every per-sample input along the batch axis.
        batch_z_0, batch_z_1, batch_sigma = (
            einops.repeat(t, "1 ... -> n ...", n=3) for t in (z_0, z_1, sigma)
        )

        text_emb = cond["c_crossattn"][0]
        null_emb = uncond["c_crossattn"][0]
        image_latent = cond["c_concat"][0]
        null_latent = uncond["c_concat"][0]
        # Branch order: [text + image, image only, neither].
        batched_cond = {
            "c_crossattn": [torch.cat([text_emb, null_emb, null_emb])],
            "c_concat": [torch.cat([image_latent, image_latent, null_latent])],
        }

        stream_0, stream_1 = self.inner_model(batch_z_0, batch_z_1, batch_sigma, cond=batched_cond)
        full_0, image_only_0, none_0 = stream_0.chunk(3)
        full_1, _, _ = stream_1.chunk(3)

        guided_0 = (
            none_0
            + text_cfg_scale * (full_0 - image_only_0)
            + image_cfg_scale * (image_only_0 - none_0)
        )
        return guided_0, full_1
| |
|
def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
    """Instantiate ``config.model`` and load its weights from ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: path of a checkpoint holding a ``"state_dict"`` entry (and
            optionally ``"global_step"``, which is printed).
        vae_ckpt: optional path of a second checkpoint. When given, every
            ``first_stage_model.<name>`` weight is replaced by ``<name>`` from
            that checkpoint's ``"state_dict"``.
        verbose: print the missing/unexpected keys reported by
            ``load_state_dict``.

    Returns:
        The instantiated model with the weights loaded.
    """
    print(f"Loading model from {ckpt}")
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "global_step" in checkpoint:
        print(f"Global Step: {checkpoint['global_step']}")
    state_dict = checkpoint["state_dict"]

    if vae_ckpt is not None:
        print(f"Loading VAE from {vae_ckpt}")
        vae_state_dict = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
        prefix = "first_stage_model."
        merged = {}
        for key, value in state_dict.items():
            merged[key] = vae_state_dict[key[len(prefix):]] if key.startswith(prefix) else value
        state_dict = merged

    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(state_dict, strict=True)
    # NOTE: with strict=True, load_state_dict raises on any key mismatch, so
    # these reports can only print if strictness is relaxed.
    for label, keys in (("missing keys:", missing), ("unexpected keys:", unexpected)):
        if verbose and len(keys) > 0:
            print(label)
            print(keys)
    return model
| |
|
def append_dims(x, target_dims):
    """Return ``x`` indexed with trailing singleton axes so it has ``target_dims`` dims.

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    index = (Ellipsis,) + tuple(None for _ in range(missing))
    return x[index]
| |
|
class CompVisDenoiser(K.external.CompVisDenoiser):
    """k-diffusion CompVis wrapper for a model whose ``apply_model`` returns a pair.

    Only the first element of the pair is treated as eps and converted into
    a denoised estimate. The second element is passed through unchanged.
    """

    def __init__(self, model, quantize=False, device='cpu'):
        super().__init__(model, quantize, device)

    def get_eps(self, *args, **kwargs):
        # Returns whatever the inner model's apply_model returns (unpacked as a
        # pair by forward()).
        return self.inner_model.apply_model(*args, **kwargs)

    def forward(self, input_0, input_1, sigma, **kwargs):
        """Denoise ``input_0`` at noise level ``sigma``.

        ``input_1`` is accepted because callers pass it positionally, but it
        is not fed to the model. Returns ``(denoised_0, eps_1)``.
        """
        c_out, c_in = [append_dims(x, input_0.ndim) for x in self.get_scalings(sigma)]
        # Put the timestep on the latent's device rather than hard-coding
        # .cuda(), so the wrapper also works on CPU or on a non-default GPU.
        t = self.sigma_to_t(sigma.float()).to(input_0.device)
        eps_0, eps_1 = self.get_eps(input_0 * c_in, t, **kwargs)
        return input_0 + eps_0 * c_out, eps_1
| |
|
def to_d(x, sigma, denoised):
    """Karras ODE derivative ``(x - denoised) / sigma``, with ``sigma`` broadcast over ``x``."""
    residual = x - denoised
    sigma_b = append_dims(sigma, x.ndim)
    return residual / sigma_b
| |
|
def default_noise_sampler(x):
    """Return a sampler ``(sigma, sigma_next) -> noise`` that ignores its arguments
    and draws fresh standard-normal noise shaped like ``x`` on every call."""
    def sampler(sigma, sigma_next):
        return torch.randn_like(x)
    return sampler
| |
|
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    """Split an ancestral sampling step from ``sigma_from`` to ``sigma_to``.

    Returns ``(sigma_down, sigma_up)``: the noise level to step down to
    deterministically, and the amount of fresh noise to add afterwards.
    ``eta`` scales the stochastic part. A falsy ``eta`` gives the purely
    deterministic step ``(sigma_to, 0.)``.
    """
    if eta:
        variance_ratio = sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2
        sigma_up = min(sigma_to, eta * variance_ratio ** 0.5)
        sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
        return sigma_down, sigma_up
    return sigma_to, 0.
| |
|
def decode_mask(mask, height=256, width=256):
    """Binarize a mask tensor into a uint8 NumPy image of shape ``(height, width, 3*c)``.

    ``mask`` is resized bilinearly to ``(height, width)``. Values > 0 become
    255 and everything else (including NaN) becomes 0. The batch axis must
    be 1 (enforced by ``rearrange``), and the channel axis is tiled three
    times (presumably c == 1, giving an RGB-like mask — confirm with callers).
    The result is copied to the CPU.
    """
    resized = nn.functional.interpolate(mask, size=(height, width), mode="bilinear", align_corners=False)
    # Same values as where(>0, 1, -1) -> (v + 1) / 2 -> clamp(0, 1) -> * 255.
    binary = torch.where(resized > 0, 255.0, 0.0)
    binary = rearrange(binary, "1 c h w -> h w c")
    tiled = binary.repeat(1, 1, 3)
    return tiled.to(torch.uint8).cpu().numpy()
| |
|
def sample_euler_ancestral(model, x_0, x_1, sigmas, height, width, extra_args=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral Euler sampler that jointly evolves an image latent and a mask.

    At every step ``model(x_0, x_1, sigma, **extra_args)`` must return
    ``(denoised_0, denoised_1)``. ``x_0`` takes an Euler step toward
    ``denoised_0`` plus ancestral noise; ``x_1`` is simply replaced by
    ``denoised_1``.

    Args:
        model: two-stream denoiser (e.g. ``CFGDenoiser``).
        x_0, x_1: initial noisy latents; the batch size is taken from ``x_0``.
        sigmas: noise schedule; ``len(sigmas) - 1`` steps are run.
        height, width: size at which each intermediate mask is decoded.
        extra_args: keyword arguments forwarded to ``model``.
        disable: forwarded to ``trange`` to hide the progress bar.
        eta: ancestral noise scale (0 gives deterministic Euler steps).
        s_noise: extra multiplier on the injected noise.
        noise_sampler: callable ``(sigma, sigma_next) -> noise``; defaults to
            fresh Gaussian noise shaped like ``x_0``.

    Returns:
        ``(x_0, x_1, image_list, mask_list)``. ``image_list`` concatenates the
        per-step ``denoised_0`` along dim 0 (an empty batch if no step ran).
        ``mask_list`` holds one decoded uint8 mask array per step.
    """
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x_0) if noise_sampler is None else noise_sampler
    s_in = x_0.new_ones([x_0.shape[0]])

    mask_list = []
    image_list = []
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        denoised_0, denoised_1 = model(x_0, x_1, sigma * s_in, **extra_args)
        image_list.append(denoised_0)

        sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, eta=eta)
        d_0 = to_d(x_0, sigma, denoised_0)
        x_0 = x_0 + d_0 * (sigma_down - sigma)
        if sigma_next > 0:
            x_0 = x_0 + noise_sampler(sigma, sigma_next) * s_noise * sigma_up

        x_1 = denoised_1
        mask_list.append(decode_mask(x_1, height, width))

    if image_list:
        image_list = torch.cat(image_list, dim=0)
    else:
        # No steps ran (len(sigmas) < 2, e.g. steps=0). torch.cat([]) would
        # raise, so return an empty batch with x_0's per-sample shape instead.
        image_list = x_0.new_empty((0, *x_0.shape[1:]))

    return x_0, x_1, image_list, mask_list
| |
|
# Command-line configuration, parsed once at import time.
parser = ArgumentParser()
# Target length (px) used to scale the input image; generate/generate_list
# then snap the final width and height to multiples of 64.
parser.add_argument("--resolution", default=512, type=int)
parser.add_argument("--config", default="configs/generate_diffree.yaml", type=str)
parser.add_argument("--ckpt", default="checkpoints/epoch=000041-step=000010999.ckpt", type=str)
# Optional second checkpoint whose weights replace the model's first_stage_model.* entries.
parser.add_argument("--vae-ckpt", default=None, type=str)
args = parser.parse_args()

# Build the model once at startup and keep it on the GPU in eval mode; the
# request handlers share these module-level objects.
config = OmegaConf.load(args.config)
model = load_model_from_config(config, args.ckpt, args.vae_ckpt)
model.eval().cuda()
# NOTE(review): the k-diffusion wrapper is constructed after .cuda(); it
# presumably captures schedule tensors from the model here, so keep this
# ordering unless verified otherwise.
model_wrap = CompVisDenoiser(model)
# Adds text/image classifier-free guidance on top of model_wrap.
model_wrap_cfg = CFGDenoiser(model_wrap)
# Embedding of the empty prompt, used as the unconditional text input.
null_token = model.get_learned_conditioning([""])
| |
|
@spaces.GPU(duration=30)
def generate(
    input_image: Image.Image,
    instruction: str,
    steps: int,
    randomize_seed: bool,
    seed: int,
    randomize_cfg: bool,
    text_cfg_scale: float,
    image_cfg_scale: float,
    weather_close_video: bool,
    decode_image_batch: int
):
    """Insert a single object described by ``instruction`` into ``input_image``.

    ``randomize_seed``, ``randomize_cfg`` and ``weather_close_video`` arrive
    as radio-button indices: index 1 means "Randomize ..." or
    "Close Image Video", and index 0 means the opposite.

    Returns the ten values wired to the "Generate" button, in this order:
    seed, text CFG, image CFG, raw edited image, mask-blended image,
    predicted mask, mask video path, per-step image video path (``None``
    unless the video is enabled), resized original, and the blended image
    with a red mask overlay.

    Raises:
        gr.Error: if no input image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image first.")

    seed = random.randint(0, 100000) if randomize_seed else seed
    text_cfg_scale = round(random.uniform(6.0, 9.0), ndigits=2) if randomize_cfg else text_cfg_scale
    image_cfg_scale = round(random.uniform(1.2, 1.8), ndigits=2) if randomize_cfg else image_cfg_scale

    # Scale so the shorter side lands on a multiple of 64 (rounded up from its
    # size when the longer side equals args.resolution), floor both sides to
    # multiples of 64, then center-fit the image to exactly that size.
    width, height = input_image.size
    factor = args.resolution / max(width, height)
    factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
    width = int((width * factor) // 64) * 64
    height = int((height * factor) // 64) * 64
    input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
    # Work on an RGB copy everywhere, so RGBA/grayscale uploads still give a
    # 3-channel encoder input and blend correctly with the 3-channel output.
    input_image_copy = input_image.convert("RGB")

    if not instruction:
        # Nothing to insert: echo the resized input. The button is wired to ten
        # outputs, so the early return must supply all of them.
        return [int(seed), text_cfg_scale, image_cfg_scale, input_image_copy, input_image_copy,
                None, None, None, input_image_copy, None]

    model.cuda()
    with torch.no_grad(), autocast("cuda"), model.ema_scope():
        cond = {}
        cond["c_crossattn"] = [model.get_learned_conditioning([instruction]).to(model.device)]
        # uint8 [0, 255] -> float [-1, 1], laid out as a 1 x C x H x W batch.
        image_tensor = 2 * torch.tensor(np.array(input_image_copy)).float() / 255 - 1
        image_tensor = rearrange(image_tensor, "h w c -> 1 c h w").to(model.device)
        cond["c_concat"] = [model.encode_first_stage(image_tensor).mode().to(model.device)]

        uncond = {}
        uncond["c_crossattn"] = [null_token.to(model.device)]
        uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]

        sigmas = model_wrap.get_sigmas(steps).to(model.device)

        extra_args = {
            "cond": cond,
            "uncond": uncond,
            "text_cfg_scale": text_cfg_scale,
            "image_cfg_scale": image_cfg_scale,
        }
        torch.manual_seed(seed)
        z_0 = torch.randn_like(cond["c_concat"][0]).to(model.device) * sigmas[0]
        z_1 = torch.randn_like(cond["c_concat"][0]).to(model.device) * sigmas[0]

        z_0, z_1, image_list, mask_list = sample_euler_ancestral(model_wrap_cfg, z_0, z_1, sigmas, height, width, extra_args=extra_args)

        x_0 = model.decode_first_stage(z_0)

        if model.first_stage_downsample:
            x_1 = nn.functional.interpolate(z_1, size=(height, width), mode="bilinear", align_corners=False)
            x_1 = torch.where(x_1 > 0, 1, -1)
        else:
            x_1 = model.decode_first_stage(z_1)

        x_0 = torch.clamp((x_0 + 1.0) / 2.0, min=0.0, max=1.0)
        x_1 = torch.clamp((x_1 + 1.0) / 2.0, min=0.0, max=1.0)
        x_0 = 255.0 * rearrange(x_0, "1 c h w -> h w c")
        x_1 = 255.0 * rearrange(x_1, "1 c h w -> h w c")
        x_1 = torch.cat([x_1, x_1, x_1], dim=-1)
        edited_image = Image.fromarray(x_0.type(torch.uint8).cpu().numpy())
        edited_mask = Image.fromarray(x_1.type(torch.uint8).cpu().numpy())

        image_video_path = None
        if not weather_close_video:
            # Decode the per-step denoised latents in chunks to bound memory.
            # Guard against 0/None from the UI, which would crash range().
            batch_size = max(1, int(decode_image_batch or 1))
            image_video = []
            for start in range(0, len(image_list), batch_size):
                frames = model.decode_first_stage(image_list[start:start + batch_size])
                frames = torch.clamp((frames + 1.0) / 2.0, min=0.0, max=1.0)
                frames = 255.0 * rearrange(frames, "b c h w -> b h w c")
                image_video.extend(frames.type(torch.uint8).cpu().numpy())

            image_video_path = "image.mp4"
            with imageio.get_writer(image_video_path, fps=30) as video:
                for frame in image_video:
                    video.append_data(frame)

    # Grow the mask slightly and feather its edge, then alpha-blend the edited
    # image over the original so pixels far from the object stay untouched.
    edited_mask_copy = edited_mask.copy()
    kernel = np.ones((3, 3), np.uint8)
    dilated = Image.fromarray(cv2.dilate(np.array(edited_mask), kernel, iterations=3))
    m_img = np.asarray(dilated.filter(ImageFilter.GaussianBlur(radius=3))).astype('float') / 255.0
    img_np = np.asarray(input_image_copy).astype('float') / 255.0
    ours_np = np.asarray(edited_image).astype('float') / 255.0
    mix_image_np = m_img * ours_np + (1 - m_img) * img_np
    mix_image = Image.fromarray((mix_image_np * 255).astype(np.uint8)).convert('RGB')

    # Visualization: tint the blended image toward (180, 0, 0) with up to 50%
    # opacity, weighted by the soft mask.
    mix_np = np.array(mix_image).astype('float')
    red = np.zeros_like(mix_np)
    red[:, :, 0] = 180.0
    overlay = mix_np * (1 - m_img / 2.0) + m_img / 2.0 * red
    mix_result_with_red_mask = Image.fromarray(overlay.astype('uint8'))

    # One mask frame per sampling step, at 30 fps.
    mask_video_path = "mask.mp4"
    with imageio.get_writer(mask_video_path, fps=30) as video:
        for frame in mask_list:
            video.append_data(frame)

    return [int(seed), text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask_copy, mask_video_path, image_video_path, input_image_copy, mix_result_with_red_mask]
| |
|
| |
|
def single_generation(model_wrap_cfg, input_image_copy, instruction, steps, seed, text_cfg_scale, image_cfg_scale, height, width):
    """Run one guided sampling pass inserting ``instruction`` into an image.

    Args:
        model_wrap_cfg: guided two-stream denoiser handed to the sampler.
        input_image_copy: image convertible with ``np.array`` to an H x W x C
            uint8 array, already sized ``width`` x ``height`` (presumably an
            RGB PIL image — confirm at call sites).
        instruction: text describing the object to insert.
        steps: number of sampling steps for the sigma schedule.
        seed: torch RNG seed for the initial latents.
        text_cfg_scale, image_cfg_scale: guidance scales.
        height, width: size at which the mask is resampled.

    Returns:
        ``(x_0, x_1, x_1_mean)``. ``x_0`` is the decoded image tensor, before
        clamping or rescaling. ``x_1`` is a ±1 mask at ``(height, width)``.
        ``x_1_mean`` is that mask's mean, where -1 means an empty mask.
    """
    model.cuda()
    with torch.no_grad(), autocast("cuda"), model.ema_scope():
        cond = {}
        # Convert to a host tensor first, then move it to the device. Calling
        # `.to(device)` on the image itself fails: PIL images have no `.to`.
        input_image_torch = 2 * torch.tensor(np.array(input_image_copy)).float() / 255 - 1
        input_image_torch = rearrange(input_image_torch, "h w c -> 1 c h w").to(model.device)
        cond["c_crossattn"] = [model.get_learned_conditioning([instruction]).to(model.device)]
        cond["c_concat"] = [model.encode_first_stage(input_image_torch).mode().to(model.device)]

        uncond = {}
        uncond["c_crossattn"] = [null_token.to(model.device)]
        uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]

        sigmas = model_wrap.get_sigmas(steps).to(model.device)

        extra_args = {
            "cond": cond,
            "uncond": uncond,
            "text_cfg_scale": text_cfg_scale,
            "image_cfg_scale": image_cfg_scale,
        }
        torch.manual_seed(seed)
        z_0 = torch.randn_like(cond["c_concat"][0]).to(model.device) * sigmas[0]
        z_1 = torch.randn_like(cond["c_concat"][0]).to(model.device) * sigmas[0]

        z_0, z_1, _, _ = sample_euler_ancestral(model_wrap_cfg, z_0, z_1, sigmas, height, width, extra_args=extra_args)

        x_0 = model.decode_first_stage(z_0)

        # Threshold the upsampled mask latent to a hard ±1 mask.
        x_1 = nn.functional.interpolate(z_1, size=(height, width), mode="bilinear", align_corners=False)
        x_1 = torch.where(x_1 > 0, 1, -1)

        x_1_mean = torch.sum(x_1).item() / x_1.numel()

        return x_0, x_1, x_1_mean
| |
|
| |
|
@spaces.GPU(duration=150)
def generate_list(
    input_image: Image.Image,
    generate_list: str,
    steps: int,
    randomize_seed: bool,
    seed: int,
    randomize_cfg: bool,
    text_cfg_scale: float,
    image_cfg_scale: float,
    weather_close_video: bool,
    decode_image_batch: int
):
    """Insert several objects one after another, streaming each result.

    ``generate_list`` is newline-separated text, and blank lines are ignored.
    Each object is inserted into the blended result of the previous one.

    After every accepted insertion, this generator yields the ten values
    wired to the "List Generate" button: seed, text CFG, image CFG, raw
    edited image, blended image, mask (None), mask video (None), image video
    path, resized original, red overlay (None).

    The image video is written once the list is finished, at 2 fps. It has
    one frame for the start image and one per accepted object.

    If a sample's predicted mask is almost empty (mean < -0.99), the seed is
    bumped and the same object is retried. The retry budget (``max_retry``)
    is shared across the whole list; once it is exhausted, each failing
    object is skipped.

    ``weather_close_video`` and ``decode_image_batch`` are accepted for
    signature parity with ``generate`` but are unused.

    Raises:
        gr.Error: if no input image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image first.")

    # One instruction per non-empty line.
    items = [line for line in (generate_list or "").split('\n') if line]

    seed = random.randint(0, 100000) if randomize_seed else seed
    text_cfg_scale = round(random.uniform(6.0, 9.0), ndigits=2) if randomize_cfg else text_cfg_scale
    image_cfg_scale = round(random.uniform(1.2, 1.8), ndigits=2) if randomize_cfg else image_cfg_scale

    width, height = input_image.size
    factor = args.resolution / max(width, height)
    factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
    width = int((width * factor) // 64) * 64
    height = int((height * factor) // 64) * 64
    input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
    # RGB working image: the encoder input, the blending and every video frame
    # assume 3 channels.
    current_image = input_image.convert("RGB")

    def outputs(edited, mixed, video_path):
        # Order matches `list_generate_button.click(outputs=...)`. Reads the
        # current `seed`, which retries may have bumped.
        return [int(seed), text_cfg_scale, image_cfg_scale, edited, mixed, None, None, video_path, input_image, None]

    if not items:
        # A `return value` inside a generator never reaches Gradio, so yield
        # the outputs explicitly before stopping.
        yield outputs(current_image, current_image, None)
        return

    model.cuda()
    image_video = [np.array(current_image).astype(np.uint8)]
    kernel = np.ones((3, 3), np.uint8)
    generate_index = 0
    retry_number = 0
    max_retry = 10
    video_path = None
    last_edited = None

    def write_video():
        path = "image.mp4"
        with imageio.get_writer(path, fps=2) as video:
            for frame in image_video:
                video.append_data(frame)
        return path

    while generate_index < len(items):
        print(f'generate_index: {str(generate_index)}')
        instruction = items[generate_index]

        # The context is closed before the `yield` below, so model.ema_scope()
        # (presumably an EMA weight swap) is not left open while suspended.
        with torch.no_grad(), autocast("cuda"), model.ema_scope():
            cond = {}
            image_tensor = 2 * torch.tensor(np.array(current_image)).float() / 255 - 1
            image_tensor = rearrange(image_tensor, "h w c -> 1 c h w").to(model.device)
            cond["c_crossattn"] = [model.get_learned_conditioning([instruction]).to(model.device)]
            cond["c_concat"] = [model.encode_first_stage(image_tensor).mode().to(model.device)]

            uncond = {}
            uncond["c_crossattn"] = [null_token.to(model.device)]
            uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]

            sigmas = model_wrap.get_sigmas(steps).to(model.device)

            extra_args = {
                "cond": cond,
                "uncond": uncond,
                "text_cfg_scale": text_cfg_scale,
                "image_cfg_scale": image_cfg_scale,
            }
            torch.manual_seed(seed)
            z_0 = torch.randn_like(cond["c_concat"][0]).to(model.device) * sigmas[0]
            z_1 = torch.randn_like(cond["c_concat"][0]).to(model.device) * sigmas[0]

            z_0, z_1, _, _ = sample_euler_ancestral(model_wrap_cfg, z_0, z_1, sigmas, height, width, extra_args=extra_args)

            x_0 = model.decode_first_stage(z_0)

            x_1 = nn.functional.interpolate(z_1, size=(height, width), mode="bilinear", align_corners=False)
            x_1 = torch.where(x_1 > 0, 1, -1)

            x_1_mean = torch.sum(x_1).item() / x_1.numel()

        if x_1_mean < -0.99:
            # (Almost) empty mask: retry this object with the next seed, or
            # skip it once the shared retry budget is exhausted.
            seed += 1
            retry_number += 1
            if retry_number > max_retry:
                generate_index += 1
            continue
        generate_index += 1

        x_0 = torch.clamp((x_0 + 1.0) / 2.0, min=0.0, max=1.0)
        x_1 = torch.clamp((x_1 + 1.0) / 2.0, min=0.0, max=1.0)
        x_0 = 255.0 * rearrange(x_0, "1 c h w -> h w c")
        x_1 = 255.0 * rearrange(x_1, "1 c h w -> h w c")
        x_1 = torch.cat([x_1, x_1, x_1], dim=-1)
        edited_image = Image.fromarray(x_0.type(torch.uint8).cpu().numpy())
        edited_mask = Image.fromarray(x_1.type(torch.uint8).cpu().numpy())

        # Dilate and feather the mask, then blend the edit over the current image.
        dilated = Image.fromarray(cv2.dilate(np.array(edited_mask), kernel, iterations=3))
        m_img = np.asarray(dilated.filter(ImageFilter.GaussianBlur(radius=3))).astype('float') / 255.0
        img_np = np.asarray(current_image).astype('float') / 255.0
        ours_np = np.asarray(edited_image).astype('float') / 255.0
        mix_image_np = m_img * ours_np + (1 - m_img) * img_np
        mix_frame = (mix_image_np * 255).astype(np.uint8)
        image_video.append(mix_frame)
        mix_image = Image.fromarray(mix_frame).convert('RGB')
        last_edited = edited_image

        if generate_index == len(items):
            video_path = write_video()

        yield outputs(edited_image, mix_image, video_path)

        # The next object is inserted into this blended result.
        current_image = mix_image

    if video_path is None:
        # The final object(s) were skipped after exhausting retries, so the
        # in-loop write never happened. Write the video and re-emit the last
        # state with the video attached.
        yield outputs(last_edited, current_image, write_video())
| |
|
| |
|
def reset():
    """Default values restored by the Reset button.

    Order matches ``reset_button.click(outputs=...)``: steps, seed mode, seed,
    CFG mode, text CFG, image CFG, seven cleared result widgets, video mode,
    and decode batch size.
    """
    cleared_results = [None] * 7
    return [100, "Randomize Seed", 1372, "Fix CFG", 7.5, 1.5, *cleared_results, "Close Image Video", 10]
| |
|
| |
|
def get_example():
    """Rows for the ``gr.Examples`` table.

    Each row is ``[image path, single-object instruction, newline-separated
    object list, steps, seed mode, seed, CFG mode, text CFG, image CFG]``.
    The example inputs after image CFG are left unspecified.
    """
    shared_settings = [100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5]
    cases = [
        ("example_images/dufu.png", "", "black and white suit\nsunglasses\nblue medical mask\nyellow schoolbag\nred bow tie\nbrown high-top hat"),
        ("example_images/girl.jpeg", "", "reflective sunglasses\nshiny golden crown\ndiamond necklace\ngorgeous yellow gown\nbeautiful tattoo"),
        ("example_images/dufu.png", "black and white suit", ""),
        ("example_images/girl.jpeg", "reflective sunglasses", ""),
        ("example_images/road_sign.png", "stop sign", ""),
        ("example_images/dufu.png", "blue medical mask", ""),
        ("example_images/people_standing.png", "dark green pleated skirt", ""),
        ("example_images/girl.jpeg", "shiny golden crown", ""),
        ("example_images/dufu.png", "sunglasses", ""),
        ("example_images/girl.jpeg", "diamond necklace", ""),
        ("example_images/iron_man.jpg", "sunglasses", ""),
        ("example_images/girl.jpeg", "the queen's crown", ""),
        ("example_images/girl.jpeg", "gorgeous yellow gown", ""),
    ]
    return [[path, single, multi, *shared_settings] for path, single, multi in cases]
| |
|
# Gradio UI: component layout plus event wiring. Components are laid out in
# creation order, so the statements below are order-sensitive.
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    # Header: title and project links.
    with gr.Row():
        gr.Markdown(
            "<div align='center'><font size='14'>Diffree: Text-Guided Shape Free Object Inpainting with Diffusion Model</font></div>"
        )
    with gr.Row():
        gr.Markdown(
            """
<div align='center'>
<a href="https://opengvlab.github.io/Diffree/"><u>[🌐Project Page]</u></a>

<a href="https://drive.google.com/file/d/1AdIPA5TK5LB1tnqqZuZ9GsJ6Zzqo2ua6/view"><u>[🎥 Video]</u></a>

<a href="https://github.com/OpenGVLab/Diffree"><u>[🔍 Code]</u></a>

<a href="https://arxiv.org/pdf/2407.16982"><u>[📜 Arxiv]</u></a>
</div>
"""
        )

    with gr.Row():
        # Left column: inputs and sampling controls.
        with gr.Column(scale=1, min_width=100):
            with gr.Row():
                input_image = gr.Image(label="Input Image", type="pil", interactive=True)
            with gr.Row():
                instruction = gr.Textbox(lines=1, label="Single object description", interactive=True)
            with gr.Row():
                reset_button = gr.Button("Reset")
                generate_button = gr.Button("Generate")
            with gr.Row():
                list_input = gr.Textbox(label="Input List", placeholder="Enter one item per line\nThe generation time increases with the quantity.", lines=10)
            with gr.Row():
                list_generate_button = gr.Button("List Generate")
            with gr.Row():
                steps = gr.Number(value=100, precision=0, label="Steps", interactive=True)
                # type="index": handlers receive 0 ("Fix Seed") or 1 ("Randomize Seed").
                randomize_seed = gr.Radio(
                    ["Fix Seed", "Randomize Seed"],
                    value="Randomize Seed",
                    type="index",
                    label="Seed Selection",
                    show_label=False,
                    interactive=True,
                )
                seed = gr.Number(value=1372, precision=0, label="Seed", interactive=True)
                # type="index": handlers receive 0 ("Fix CFG") or 1 ("Randomize CFG").
                randomize_cfg = gr.Radio(
                    ["Fix CFG", "Randomize CFG"],
                    value="Fix CFG",
                    type="index",
                    label="CFG Selection",
                    show_label=False,
                    interactive=True,
                )
                text_cfg_scale = gr.Number(value=7.5, label=f"Text CFG", interactive=True)
                image_cfg_scale = gr.Number(value=1.5, label=f"Image CFG", interactive=True)
        # Right column: main results.
        with gr.Column(scale=1, min_width=100):
            with gr.Column():
                mix_image = gr.Image(label=f"Mix Image", type="pil", interactive=False)
            with gr.Column():
                edited_mask = gr.Image(label=f"Output Mask", type="pil", interactive=False)

    # Collapsed section: generation-process videos and intermediate images.
    with gr.Accordion('👇 Click to see more (includes generation process per object for list generation and per step for single generation)', open=False):
        with gr.Row():
            # type="index": 0 = "Show Image Video", 1 = "Close Image Video".
            weather_close_video = gr.Radio(
                ["Show Image Video", "Close Image Video"],
                value="Close Image Video",
                type="index",
                label="Image Generation Process Selection For Single Generation (close for faster generation)",
                interactive=True,
            )
            # Number of latents decoded per chunk when building the per-step video.
            decode_image_batch = gr.Number(value=10, precision=0, label="Decode Image Batch (<steps)", interactive=True)
        with gr.Row():
            image_video = gr.Video(label="Image Video of Generation Process")
            mask_video = gr.Video(label="Mask Video of Generation Process")
        with gr.Row():
            original_image = gr.Image(label=f"Original Image", type="pil", interactive=False)
            edited_image = gr.Image(label=f"Output Image", type="pil", interactive=False)
            mix_result_with_red_mask = gr.Image(label=f"Mix Image With Red Mask", type="pil", interactive=False)

    # Example rows only fill the inputs (fn=None, nothing cached).
    with gr.Row():
        gr.Examples(
            examples=get_example(),
            inputs=[input_image, instruction, list_input, steps, randomize_seed, seed, randomize_cfg, text_cfg_scale, image_cfg_scale, weather_close_video, decode_image_batch],
            fn=None,
            outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask, mask_video, image_video, original_image, mix_result_with_red_mask],
            cache_examples = False
        )

    # Single-object generation: `generate` must return one value per output below.
    generate_button.click(
        fn=generate,
        inputs=[
            input_image,
            instruction,
            steps,
            randomize_seed,
            seed,
            randomize_cfg,
            text_cfg_scale,
            image_cfg_scale,
            weather_close_video,
            decode_image_batch
        ],
        outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask, mask_video, image_video, original_image, mix_result_with_red_mask],
    )

    # List generation: `generate_list` is a generator that streams updates to the same outputs.
    list_generate_button.click(
        fn=generate_list,
        inputs=[
            input_image,
            list_input,
            steps,
            randomize_seed,
            seed,
            randomize_cfg,
            text_cfg_scale,
            image_cfg_scale,
            weather_close_video,
            decode_image_batch
        ],
        outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask, mask_video, image_video, original_image, mix_result_with_red_mask],
    )

    # Reset restores default settings and clears result widgets (see reset()).
    reset_button.click(
        fn=reset,
        inputs=[],
        outputs=[steps, randomize_seed, seed, randomize_cfg, text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask, mask_video, image_video, original_image, mix_result_with_red_mask, weather_close_video, decode_image_batch],
    )
| |
|
| |
|
| | |
| | |
| |
|
| |
|
| | |
# Start the server only when run as a script. Importing the module (e.g.
# Gradio reload mode or tests) then defines `demo` without launching it.
if __name__ == "__main__":
    demo.queue().launch()
| |
|