| | import argparse |
| | import os |
| | import os.path as osp |
| | import torchvision.transforms.functional as TF |
| | import torch.nn.functional as F |
| | import cv2 |
| | import tempfile |
| | import imageio |
| | import torch |
| | import decord |
| | from PIL import Image |
| | import numpy as np |
| | from rembg import remove, new_session |
| | import random |
| | import ffmpeg |
| | import os |
| | import tempfile |
| | import subprocess |
| | import json |
| | import time |
| | from functools import lru_cache |
| | os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg") |
| |
|
| |
|
| | from PIL import Image |
| | video_info_cache = [] |
| | def seed_everything(seed: int): |
| | random.seed(seed) |
| | np.random.seed(seed) |
| | torch.manual_seed(seed) |
| | if torch.cuda.is_available(): |
| | torch.cuda.manual_seed(seed) |
| | if torch.backends.mps.is_available(): |
| | torch.mps.manual_seed(seed) |
| |
|
| | def has_video_file_extension(filename): |
| | extension = os.path.splitext(filename)[-1].lower() |
| | return extension in [".mp4", ".mkv"] |
| |
|
| | def has_image_file_extension(filename): |
| | extension = os.path.splitext(filename)[-1].lower() |
| | return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"] |
| |
|
| | def has_audio_file_extension(filename): |
| | extension = os.path.splitext(filename)[-1].lower() |
| | return extension in [".wav", ".mp3", ".aac"] |
| |
|
| | def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame ): |
| | import math |
| |
|
| | video_frame_duration = 1 /video_fps |
| | target_frame_duration = 1 / target_fps |
| | |
| | target_time = start_target_frame * target_frame_duration |
| | frame_no = math.ceil(target_time / video_frame_duration) |
| | cur_time = frame_no * video_frame_duration |
| | frame_ids =[] |
| | while True: |
| | if max_target_frames_count != 0 and len(frame_ids) >= max_target_frames_count : |
| | break |
| | diff = round( (target_time -cur_time) / video_frame_duration , 5) |
| | add_frames_count = math.ceil( diff) |
| | frame_no += add_frames_count |
| | if frame_no >= video_frames_count: |
| | break |
| | frame_ids.append(frame_no) |
| | cur_time += add_frames_count * video_frame_duration |
| | target_time += target_frame_duration |
| | frame_ids = frame_ids[:max_target_frames_count] |
| | return frame_ids |
| |
|
| | import os |
| | from datetime import datetime |
| |
|
| | def get_file_creation_date(file_path): |
| | |
| | if os.name == 'nt': |
| | return datetime.fromtimestamp(os.path.getctime(file_path)) |
| | |
| | else: |
| | stat = os.stat(file_path) |
| | return datetime.fromtimestamp(stat.st_birthtime if hasattr(stat, 'st_birthtime') else stat.st_mtime) |
| |
|
| | def sanitize_file_name(file_name, rep =""): |
| | return file_name.replace("/",rep).replace("\\",rep).replace("*",rep).replace(":",rep).replace("|",rep).replace("?",rep).replace("<",rep).replace(">",rep).replace("\"",rep).replace("\n",rep).replace("\r",rep) |
| |
|
| | def truncate_for_filesystem(s, max_bytes=None): |
| | if max_bytes is None: |
| | max_bytes = 50 if os.name == 'nt'else 100 |
| |
|
| | if len(s.encode('utf-8')) <= max_bytes: return s |
| | l, r = 0, len(s) |
| | while l < r: |
| | m = (l + r + 1) // 2 |
| | if len(s[:m].encode('utf-8')) <= max_bytes: l = m |
| | else: r = m - 1 |
| | return s[:l] |
| |
|
| | def get_default_workers(): |
| | return os.cpu_count()/ 2 |
| |
|
| | def process_images_multithread(image_processor, items, process_type, wrap_in_list = True, max_workers: int = os.cpu_count()/ 2, in_place = False) : |
| | if not items: |
| | return [] |
| |
|
| | import concurrent.futures |
| | start_time = time.time() |
| | |
| | if process_type in ["prephase", "upsample"]: |
| | if wrap_in_list : |
| | items_list = [ [img] for img in items] |
| | else: |
| | items_list = items |
| | if max_workers == 1: |
| | results = [] |
| | for idx, item in enumerate(items): |
| | item = image_processor(item) |
| | results.append(item) |
| | if wrap_in_list: items_list[idx] = None |
| | if in_place: items[idx] = item[0] if wrap_in_list else item |
| | else: |
| | with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: |
| | futures = {executor.submit(image_processor, img): idx for idx, img in enumerate(items_list)} |
| | results = [None] * len(items_list) |
| | for future in concurrent.futures.as_completed(futures): |
| | idx = futures[future] |
| | results[idx] = future.result() |
| | if wrap_in_list: items_list[idx] = None |
| | if in_place: |
| | items[idx] = results[idx][0] if wrap_in_list else results[idx] |
| |
|
| | if wrap_in_list: |
| | results = [ img[0] for img in results] |
| | else: |
| | results= image_processor(items) |
| |
|
| | end_time = time.time() |
| | |
| |
|
| | return results |
| | @lru_cache(maxsize=100) |
| | def get_video_info(video_path): |
| | global video_info_cache |
| | import cv2 |
| | cap = cv2.VideoCapture(video_path) |
| | |
| | |
| | fps = round(cap.get(cv2.CAP_PROP_FPS)) |
| | |
| | |
| | width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) |
| | height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) |
| | frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) |
| | cap.release() |
| | |
| | return fps, width, height, frame_count |
| |
|
| | def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, target_fps = None, return_PIL = True) -> torch.Tensor: |
| | """Extract nth frame from video as PyTorch tensor normalized to [-1, 1].""" |
| | cap = cv2.VideoCapture(file_name) |
| | |
| | if not cap.isOpened(): |
| | raise ValueError(f"Cannot open video: {file_name}") |
| | |
| | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) |
| | fps = round(cap.get(cv2.CAP_PROP_FPS)) |
| | if target_fps is not None: |
| | frame_no = round(target_fps * frame_no /fps) |
| |
|
| | |
| | if frame_no >= total_frames or frame_no < 0: |
| | if return_last_if_missing: |
| | frame_no = total_frames - 1 |
| | else: |
| | cap.release() |
| | raise IndexError(f"Frame {frame_no} out of bounds (0-{total_frames-1})") |
| | |
| | |
| | cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no) |
| | ret, frame = cap.read() |
| | cap.release() |
| | |
| | if not ret: |
| | raise ValueError(f"Failed to read frame {frame_no}") |
| | |
| | |
| | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) |
| | if return_PIL: |
| | return Image.fromarray(frame) |
| | else: |
| | return (torch.from_numpy(frame).permute(2, 0, 1).float() / 127.5) - 1.0 |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | def convert_image_to_video(image): |
| | if image is None: |
| | return None |
| | |
| | |
| | if isinstance(image, np.ndarray): |
| | |
| | img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) |
| | else: |
| | |
| | img_array = np.array(image) |
| | img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR) |
| | |
| | height, width = img_bgr.shape[:2] |
| | |
| | |
| | with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_video: |
| | fourcc = cv2.VideoWriter_fourcc(*'mp4v') |
| | out = cv2.VideoWriter(temp_video.name, fourcc, 30.0, (width, height)) |
| | out.write(img_bgr) |
| | out.release() |
| | return temp_video.name |
| | |
| | def resize_lanczos(img, h, w): |
| | img = (img + 1).float().mul_(127.5) |
| | img = Image.fromarray(np.clip(img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) |
| | img = img.resize((w,h), resample=Image.Resampling.LANCZOS) |
| | img = torch.from_numpy(np.array(img).astype(np.float32)).movedim(-1, 0) |
| | img = img.div(127.5).sub_(1) |
| | return img |
| |
|
| | def remove_background(img, session=None): |
| | if session ==None: |
| | session = new_session() |
| | img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) |
| | img = remove(img, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') |
| | return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0) |
| |
|
| |
|
| | def convert_image_to_tensor(image): |
| | return torch.from_numpy(np.array(image).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0) |
| |
|
| | def convert_tensor_to_image(t, frame_no = 0, mask_levels = False): |
| | if len(t.shape) == 4: |
| | t = t[:, frame_no] |
| | if t.shape[0]== 1: |
| | t = t.expand(3,-1,-1) |
| | if mask_levels: |
| | return Image.fromarray(t.clone().mul_(255).permute(1,2,0).to(torch.uint8).cpu().numpy()) |
| | else: |
| | return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy()) |
| |
|
| | def save_image(tensor_image, name, frame_no = -1): |
| | convert_tensor_to_image(tensor_image, frame_no).save(name) |
| |
|
| | def get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_dims): |
| | outpainting_top, outpainting_bottom, outpainting_left, outpainting_right= outpainting_dims |
| | frame_height = int(frame_height * (100 + outpainting_top + outpainting_bottom) / 100) |
| | frame_width = int(frame_width * (100 + outpainting_left + outpainting_right) / 100) |
| | return frame_height, frame_width |
| |
|
| | def rgb_bw_to_rgba_mask(img, thresh=127): |
| | arr = np.array(img.convert('L')) |
| | alpha = (arr > thresh).astype(np.uint8) * 255 |
| | rgba = np.dstack([np.full_like(alpha, 255)] * 3 + [alpha]) |
| | return Image.fromarray(rgba, 'RGBA') |
| |
|
| |
|
| | def get_outpainting_frame_location(final_height, final_width, outpainting_dims, block_size = 8): |
| | outpainting_top, outpainting_bottom, outpainting_left, outpainting_right= outpainting_dims |
| | raw_height = int(final_height / ((100 + outpainting_top + outpainting_bottom) / 100)) |
| | height = int(raw_height / block_size) * block_size |
| | extra_height = raw_height - height |
| | |
| | raw_width = int(final_width / ((100 + outpainting_left + outpainting_right) / 100)) |
| | width = int(raw_width / block_size) * block_size |
| | extra_width = raw_width - width |
| | margin_top = int(outpainting_top/(100 + outpainting_top + outpainting_bottom) * final_height) |
| | if extra_height != 0 and (outpainting_top + outpainting_bottom) != 0: |
| | margin_top += int(outpainting_top / (outpainting_top + outpainting_bottom) * extra_height) |
| | if (margin_top + height) > final_height or outpainting_bottom == 0: margin_top = final_height - height |
| | margin_left = int(outpainting_left/(100 + outpainting_left + outpainting_right) * final_width) |
| | if extra_width != 0 and (outpainting_left + outpainting_right) != 0: |
| | margin_left += int(outpainting_left / (outpainting_left + outpainting_right) * extra_height) |
| | if (margin_left + width) > final_width or outpainting_right == 0: margin_left = final_width - width |
| | return height, width, margin_top, margin_left |
| |
|
| | def rescale_and_crop(img, w, h): |
| | ow, oh = img.size |
| | target_ratio = w / h |
| | orig_ratio = ow / oh |
| | |
| | if orig_ratio > target_ratio: |
| | |
| | nw = int(oh * target_ratio) |
| | img = img.crop(((ow - nw) // 2, 0, (ow + nw) // 2, oh)) |
| | else: |
| | |
| | nh = int(ow / target_ratio) |
| | img = img.crop((0, (oh - nh) // 2, ow, (oh + nh) // 2)) |
| | |
| | return img.resize((w, h), Image.LANCZOS) |
| |
|
| | def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = 16): |
| | if fit_into_canvas == None or fit_into_canvas == 2: |
| | |
| | return canvas_height, canvas_width |
| | if fit_into_canvas == 1: |
| | scale1 = min(canvas_height / image_height, canvas_width / image_width) |
| | scale2 = min(canvas_width / image_height, canvas_height / image_width) |
| | scale = max(scale1, scale2) |
| | else: |
| | scale = (canvas_height * canvas_width / (image_height * image_width))**(1/2) |
| |
|
| | new_height = round( image_height * scale / block_size) * block_size |
| | new_width = round( image_width * scale / block_size) * block_size |
| | return new_height, new_width |
| |
|
| | def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fit_into_canvas, fit_crop, block_size = 16): |
| | if fit_crop: |
| | image = rescale_and_crop(image, canvas_width, canvas_height) |
| | new_width, new_height = image.size |
| | else: |
| | image_width, image_height = image.size |
| | new_height, new_width = calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = block_size ) |
| | image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) |
| | return image, new_height, new_width |
| |
|
| | def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5, return_tensor = False, ignore_last_refs = 0 ): |
| | if rm_background: |
| | session = new_session() |
| |
|
| | output_list =[] |
| | output_mask_list =[] |
| | for i, img in enumerate(img_list if ignore_last_refs == 0 else img_list[:-ignore_last_refs]): |
| | width, height = img.size |
| | resized_mask = None |
| | if any_background_ref == 1 and i==0 or any_background_ref == 2: |
| | if outpainting_dims is not None and background_ref_outpainted: |
| | resized_image, resized_mask = fit_image_into_canvas(img, (budget_height, budget_width), inpaint_color, full_frame = True, outpainting_dims = outpainting_dims, return_mask= True, return_image= True) |
| | elif img.size != (budget_width, budget_height): |
| | resized_image= img.resize((budget_width, budget_height), resample=Image.Resampling.LANCZOS) |
| | else: |
| | resized_image =img |
| | elif fit_into_canvas == 1: |
| | white_canvas = np.ones((budget_height, budget_width, 3), dtype=np.uint8) * 255 |
| | scale = min(budget_height / height, budget_width / width) |
| | new_height = int(height * scale) |
| | new_width = int(width * scale) |
| | resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS) |
| | top = (budget_height - new_height) // 2 |
| | left = (budget_width - new_width) // 2 |
| | white_canvas[top:top + new_height, left:left + new_width] = np.array(resized_image) |
| | resized_image = Image.fromarray(white_canvas) |
| | else: |
| | scale = (budget_height * budget_width / (height * width))**(1/2) |
| | new_height = int( round(height * scale / block_size) * block_size) |
| | new_width = int( round(width * scale / block_size) * block_size) |
| | resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS) |
| | if rm_background and not (any_background_ref and i==0 or any_background_ref == 2) : |
| | |
| | resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') |
| | if return_tensor: |
| | output_list.append(convert_image_to_tensor(resized_image).unsqueeze(1)) |
| | else: |
| | output_list.append(resized_image) |
| | output_mask_list.append(resized_mask) |
| | if ignore_last_refs: |
| | for img in img_list[-ignore_last_refs:]: |
| | output_list.append(convert_image_to_tensor(img).unsqueeze(1) if return_tensor else img) |
| | output_mask_list.append(None) |
| |
|
| | return output_list, output_mask_list |
| |
|
| | def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu", full_frame = False, outpainting_dims = None, return_mask = False, return_image = False): |
| | from shared.utils.utils import save_image |
| | inpaint_color = canvas_tf_bg / 127.5 - 1 |
| |
|
| | ref_width, ref_height = ref_img.size |
| | if (ref_height, ref_width) == image_size and outpainting_dims == None: |
| | ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) |
| | canvas = torch.zeros_like(ref_img[:1]) if return_mask else None |
| | else: |
| | if outpainting_dims != None: |
| | final_height, final_width = image_size |
| | canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1) |
| | else: |
| | canvas_height, canvas_width = image_size |
| | if full_frame: |
| | new_height = canvas_height |
| | new_width = canvas_width |
| | top = left = 0 |
| | else: |
| | |
| | |
| | |
| | |
| | scale = min(canvas_height / ref_height, canvas_width / ref_width) |
| | new_height = int(ref_height * scale) |
| | new_width = int(ref_width * scale) |
| | top = (canvas_height - new_height) // 2 |
| | left = (canvas_width - new_width) // 2 |
| | ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) |
| | ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) |
| | if outpainting_dims != None: |
| | canvas = torch.full((3, 1, final_height, final_width), inpaint_color, dtype= torch.float, device=device) |
| | canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img |
| | else: |
| | canvas = torch.full((3, 1, canvas_height, canvas_width), inpaint_color, dtype= torch.float, device=device) |
| | canvas[:, :, top:top + new_height, left:left + new_width] = ref_img |
| | ref_img = canvas |
| | canvas = None |
| | if return_mask: |
| | if outpainting_dims != None: |
| | canvas = torch.ones((1, 1, final_height, final_width), dtype= torch.float, device=device) |
| | canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = 0 |
| | else: |
| | canvas = torch.ones((1, 1, canvas_height, canvas_width), dtype= torch.float, device=device) |
| | canvas[:, :, top:top + new_height, left:left + new_width] = 0 |
| | canvas = canvas.to(device) |
| | if return_image: |
| | return convert_tensor_to_image(ref_img), canvas |
| |
|
| | return ref_img.to(device), canvas |
| |
|
| | def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None, device ="cpu"): |
| | src_videos, src_masks = [], [] |
| | inpaint_color_compressed = guide_inpaint_color/127.5 - 1 |
| | prepend_count = pre_video_guide.shape[1] if pre_video_guide is not None else 0 |
| | for guide_no, (cur_video_guide, cur_video_mask) in enumerate(zip(video_guides, video_masks)): |
| | src_video, src_mask = cur_video_guide, cur_video_mask |
| | if pre_video_guide is not None: |
| | src_video = pre_video_guide if src_video is None else torch.cat( [pre_video_guide, src_video], dim=1) |
| | if any_mask: |
| | src_mask = torch.zeros_like(pre_video_guide[:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[:1]), src_mask], dim=1) |
| |
|
| | if any_guide_padding: |
| | if src_video is None: |
| | src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color_compressed, dtype = torch.float, device= device) |
| | elif src_video.shape[1] < current_video_length: |
| | src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color_compressed, dtype = src_video.dtype, device= src_video.device) ], dim=1) |
| | elif src_video is not None: |
| | new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1 |
| | src_video = src_video[:, :new_num_frames] |
| |
|
| | if any_mask and src_video is not None: |
| | if src_mask is None: |
| | src_mask = torch.ones_like(src_video[:1]) |
| | elif src_mask.shape[1] < src_video.shape[1]: |
| | src_mask = torch.cat([src_mask, torch.full( (1, src_video.shape[1]- src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1) |
| | else: |
| | src_mask = src_mask[:, :src_video.shape[1]] |
| |
|
| | if src_video is not None : |
| | for k, keep in enumerate(keep_video_guide_frames): |
| | if not keep: |
| | pos = prepend_count + k |
| | src_video[:, pos:pos+1] = inpaint_color_compressed |
| | if any_mask: src_mask[:, pos:pos+1] = 1 |
| |
|
| | for k, frame in enumerate(inject_frames): |
| | if frame != None: |
| | pos = prepend_count + k |
| | src_video[:, pos:pos+1], msk = fit_image_into_canvas(frame, image_size, guide_inpaint_color, device, True, outpainting_dims, return_mask= any_mask) |
| | if any_mask: src_mask[:, pos:pos+1] = msk |
| | src_videos.append(src_video) |
| | src_masks.append(src_mask) |
| | return src_videos, src_masks |
| |
|
| |
|
| |
|