| import torch |
| import cv2 |
| import numpy as np |
| from torchvision.transforms.functional import normalize |
| from tqdm import tqdm |
| from PIL import Image, ImageOps |
| import random |
| import os |
| import requests |
| from insightface.app import FaceAnalysis |
| from facexlib.parsing import init_parsing_model |
| from typing import Union, Optional, Tuple, List |
|
|
| |
def tensor_to_cv2_img(tensor_frame: torch.Tensor) -> np.ndarray:
    """Convert a single RGB torch tensor to a BGR OpenCV image.

    Args:
        tensor_frame: Image tensor indexed as (H, W, C). Values are scaled
            by 255, so they are presumably floats in [0, 1] (confirm against
            callers); anything outside that range is clamped.

    Returns:
        The ``uint8`` array produced by ``cv2.cvtColor(..., COLOR_RGB2BGR)``.

    Raises:
        ValueError: If the tensor is not 3-D or its last axis does not hold
            3 or 4 channels.
    """
    # Fail early with a clear message, mirroring tensor_to_cv2_bgra_img.
    # 4-channel frames are not rejected here so that any caller relying on
    # OpenCV's own handling of them keeps working.
    if tensor_frame.ndim != 3 or tensor_frame.shape[2] not in (3, 4):
        raise ValueError("Input tensor must be an (H, W, C) image with 3 or 4 channels.")

    # detach(): .numpy() refuses tensors that are part of an autograd graph.
    scaled = tensor_frame.detach().cpu().numpy() * 255.0
    # Round to nearest so float error in a /255 -> *255 round trip cannot
    # drop a level, and clamp so out-of-range values saturate instead of
    # wrapping around when cast to uint8.
    img_np = np.clip(np.rint(scaled), 0, 255).astype(np.uint8)
    return cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
|
def tensor_to_cv2_bgra_img(tensor_frame: torch.Tensor) -> np.ndarray:
    """Convert a single RGBA torch tensor to a BGRA OpenCV image.

    Args:
        tensor_frame: Image tensor indexed as (H, W, 4). Values are scaled
            by 255, so they are presumably floats in [0, 1] (confirm against
            callers); anything outside that range is clamped.

    Returns:
        The ``uint8`` array produced by ``cv2.cvtColor(..., COLOR_RGBA2BGRA)``.

    Raises:
        ValueError: If the tensor is not a 3-D image with exactly 4 channels.
    """
    # Check the rank first: indexing shape[2] on a 2-D tensor would raise an
    # IndexError instead of the documented ValueError.
    if tensor_frame.ndim != 3 or tensor_frame.shape[2] != 4:
        raise ValueError("Input tensor must be an RGBA image with 4 channels.")

    # detach(): .numpy() refuses tensors that are part of an autograd graph.
    scaled = tensor_frame.detach().cpu().numpy() * 255.0
    # Round to nearest and clamp so values saturate at 0/255 instead of
    # truncating or wrapping around when cast to uint8.
    img_np = np.clip(np.rint(scaled), 0, 255).astype(np.uint8)
    return cv2.cvtColor(img_np, cv2.COLOR_RGBA2BGRA)
|
|
def pil_to_tensor(image: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a float32 torch tensor scaled by 1/255.

    The image is converted with ``np.array``, cast to ``float32`` and divided
    by 255, so the tensor keeps whatever axis layout ``np.array(image)``
    produced.
    """
    pixels = np.array(image)
    scaled = pixels.astype(np.float32) / 255.0
    return torch.from_numpy(scaled)
|
|
class VideoMaskGenerator:
    """Paste a source face onto the largest detected face of each video frame.

    For every frame the generator also produces a mask marking where the face
    was pasted. Face detection uses insightface's ``FaceAnalysis``; the
    fallback paste path uses a facexlib parsing model to build the mask.
    """

    # Label ids of the parsing model's output that are kept in the mask.
    # NOTE(review): presumably the facial-skin/feature classes of facexlib's
    # BiSeNet label map (excluding neck, hair, clothes) - confirm.
    _FACE_PARSING_LABELS = tuple(range(1, 14))
    # Side length (pixels) of the square input fed to the parsing model.
    _PARSING_INPUT_SIZE = 512

    def __init__(self, antelopv2_path=".", device: Optional[torch.device] = None):
        """Load the detection and parsing models.

        Args:
            antelopv2_path: Passed as ``root`` to ``FaceAnalysis``.
                NOTE(review): presumably the directory containing
                ``models/antelopev2`` - confirm against insightface.
            device: Device for the parsing model; also selects the ONNX
                execution provider. Defaults to CUDA when available,
                otherwise CPU.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        elif not isinstance(device, torch.device):
            # Tolerate callers passing a device string such as "cuda:0".
            device = torch.device(device)
        self.device = device

        print(f"Using device: {self.device}")

        on_cuda = self.device.type == "cuda"
        providers = ["CUDAExecutionProvider"] if on_cuda else ["CPUExecutionProvider"]

        self.detection_model = FaceAnalysis(name="antelopev2", root=antelopv2_path, providers=providers)
        # insightface convention: a negative ctx_id selects CPU execution.
        self.detection_model.prepare(ctx_id=0 if on_cuda else -1, det_size=(640, 640))

        self.parsing_model = init_parsing_model(model_name="bisenet", device=self.device)
        self.parsing_model.eval()

        print("VideoMaskGenerator initialized successfully.")

    def process(
        self,
        video_path: str,
        face_image: Union[str, Image.Image],
        confidence_threshold: float = 0.5,
        face_crop_scale: float = 1.5,
        dilation_kernel_size: int = 10,
        feather_amount: int = 21,
        random_horizontal_flip_chance: float = 0.0,
        match_angle_and_size: bool = True
    ) -> Tuple[np.ndarray, np.ndarray, int, int, int]:
        """
        Processes a video to replace a face with a provided face image.

        The video is centre-cropped so height and width are multiples of 16
        and trimmed to the largest frame count of the form ``4 * k + 1``.
        In each frame the largest detected face is the paste target. When
        landmark matching is enabled and succeeds, the source face is warped
        onto the target with a similarity transform and its alpha channel
        becomes the mask. Otherwise the source face is resized into a scaled
        square box around the detection and blended through a mask from the
        parsing model.

        Args:
            video_path (str): Path to the input video file.
            face_image (Union[str, Image.Image]): Local path, http(s) URL, or
                PIL image of the face to paste.
            confidence_threshold (float): The largest face must score above
                this value for the frame to be modified.
            face_crop_scale (float): Scale factor for the square box around
                the detected face (fallback path only).
            dilation_kernel_size (int): Kernel size for mask dilation
                (fallback path only); values <= 0 disable dilation.
            feather_amount (int): Gaussian kernel size for feathering the
                mask (fallback path only); even values are bumped to the next
                odd number, values <= 0 disable feathering.
            random_horizontal_flip_chance (float): Chance to flip the source
                face horizontally (fallback path only).
            match_angle_and_size (bool): Whether to use landmark matching for
                rotation and scale.

        Returns:
            Tuple[np.ndarray, np.ndarray, int, int, int]:
                - Processed video as a uint8 numpy array (F, H, W, C), RGB.
                - Generated masks as a uint8 numpy array (F, H, W, 1).
                - Width of the processed video.
                - Height of the processed video.
                - Number of frames in the processed video.

        Raises:
            FileNotFoundError: If ``video_path`` does not exist.
            ValueError: If no frame can be read, or the frames are smaller
                than 16 pixels in either dimension.
        """
        if not os.path.exists(video_path):
            raise FileNotFoundError(f"Video file not found at: {video_path}")

        video = self._load_video(video_path)
        final_frames, final_h, final_w = video.shape[:3]
        print(f"Video pre-processed: {final_w}x{final_h}, {final_frames} frames.")

        face_pil = self._load_face_image(face_image)
        # Work on the uint8 pixels directly (no float round trip) and split
        # the channels once instead of once per frame.
        face_rgba = np.array(face_pil)
        source_rgb = np.ascontiguousarray(face_rgba[..., :3])
        source_alpha = np.ascontiguousarray(face_rgba[..., 3])

        source_kpts = None
        if match_angle_and_size:
            source_faces = self.detection_model.get(cv2.cvtColor(face_rgba, cv2.COLOR_RGBA2BGR))
            if source_faces:
                source_kpts = source_faces[0].kps
            else:
                print("[Warning] No face or landmarks found in source image. Disabling angle matching.")

        # Loop-invariant mask post-processing settings.
        dilation_kernel = None
        if dilation_kernel_size > 0:
            dilation_kernel = cv2.getStructuringElement(
                cv2.MORPH_ELLIPSE, (dilation_kernel_size, dilation_kernel_size)
            )
        if feather_amount > 0 and feather_amount % 2 == 0:
            feather_amount += 1  # GaussianBlur needs an odd kernel size

        def bbox_area(face):
            return (face.bbox[2] - face.bbox[0]) * (face.bbox[3] - face.bbox[1])

        processed_frames_list = []
        mask_list = []

        for frame_rgb in tqdm(video, desc="Pasting face onto frames"):
            faces = self.detection_model.get(cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR))
            final_mask = np.zeros((final_h, final_w), dtype=np.uint8)

            if faces:
                largest_face = max(faces, key=bbox_area)
                if largest_face.det_score > confidence_threshold:
                    result = None
                    if source_kpts is not None:
                        result = self._paste_aligned(
                            frame_rgb, source_rgb, source_alpha, source_kpts, largest_face.kps
                        )
                    if result is None:
                        # Landmark matching disabled or no transform found.
                        result = self._paste_in_box(
                            frame_rgb,
                            face_pil,
                            largest_face.bbox,
                            face_crop_scale,
                            dilation_kernel,
                            feather_amount,
                            random_horizontal_flip_chance,
                        )
                    if result is not None:
                        frame_rgb, final_mask = result

            processed_frames_list.append(frame_rgb)
            mask_list.append(final_mask)

        output_video = np.stack(processed_frames_list)
        # Trailing channel axis: masks are returned as (F, H, W, 1).
        output_masks = np.stack(mask_list)[..., np.newaxis]

        return (output_video, output_masks, final_w, final_h, final_frames)

    @staticmethod
    def _load_video(video_path: str) -> np.ndarray:
        """Read all frames as RGB, centre-crop to multiples of 16, trim to 4k+1 frames."""
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            # Release the capture even if decoding or conversion raises.
            cap.release()

        if not frames:
            raise ValueError("Could not read any frames from the video.")

        video_np = np.array(frames)
        h, w = video_np.shape[1], video_np.shape[2]
        new_h, new_w = (h // 16) * 16, (w // 16) * 16
        if new_h == 0 or new_w == 0:
            raise ValueError(
                f"Video frames are too small ({w}x{h}); at least 16x16 pixels are required."
            )

        y_start = (h - new_h) // 2
        x_start = (w - new_w) // 2
        video_cropped = video_np[:, y_start:y_start + new_h, x_start:x_start + new_w, :]

        # Largest count of the form 4k + 1 that does not exceed the number of
        # frames. `(n // 4) * 4 + 1` would overshoot when n is a multiple of
        # 4, leaving the frame count untrimmed.
        num_frames = video_cropped.shape[0]
        target_frames = ((num_frames - 1) // 4) * 4 + 1
        return video_cropped[:target_frames]

    @staticmethod
    def _load_face_image(face_image: Union[str, Image.Image]) -> Image.Image:
        """Load the source face from a path, URL or PIL image as an upright RGBA image."""
        if isinstance(face_image, str):
            # Match the URL scheme exactly so a local file whose name merely
            # starts with "http" is still opened from disk.
            if face_image.startswith(("http://", "https://")):
                from io import BytesIO

                # NOTE(review): the URL is fetched as given; restrict it
                # upstream if it can come from untrusted input.
                response = requests.get(face_image, timeout=10)
                response.raise_for_status()
                face_image = Image.open(BytesIO(response.content))
            else:
                face_image = Image.open(face_image)

        return ImageOps.exif_transpose(face_image).convert("RGBA")

    @staticmethod
    def _paste_aligned(
        frame_rgb: np.ndarray,
        source_rgb: np.ndarray,
        source_alpha: np.ndarray,
        source_kpts: np.ndarray,
        target_kpts: np.ndarray,
    ) -> Optional[Tuple[np.ndarray, np.ndarray]]:
        """Warp the source face onto the target landmarks; return (frame, mask) or None."""
        matrix, _ = cv2.estimateAffinePartial2D(source_kpts, target_kpts, method=cv2.LMEDS)
        if matrix is None:
            return None

        frame_h, frame_w = frame_rgb.shape[:2]
        out_size = (frame_w, frame_h)
        warped_face = cv2.warpAffine(source_rgb, matrix, out_size)
        warped_alpha = cv2.warpAffine(source_alpha, matrix, out_size)

        # Alpha-blend the warped face over the frame.
        alpha = (warped_alpha.astype(np.float32) / 255.0)[..., np.newaxis]
        blended = (1.0 - alpha) * frame_rgb + alpha * warped_face
        return blended.astype(np.uint8), warped_alpha

    def _paste_in_box(
        self,
        frame_rgb: np.ndarray,
        face_pil: Image.Image,
        bbox,
        face_crop_scale: float,
        dilation_kernel: Optional[np.ndarray],
        feather_amount: int,
        flip_chance: float,
    ) -> Optional[Tuple[np.ndarray, np.ndarray]]:
        """Resize the source face into a scaled box around ``bbox``; return (frame, mask) or None."""
        frame_h, frame_w = frame_rgb.shape[:2]
        x1, y1, x2, y2 = map(int, bbox)
        center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
        half_side = int(max(x2 - x1, y2 - y1) * face_crop_scale) // 2

        # Square box around the detection, clipped to the frame.
        crop_y1, crop_x1 = max(center_y - half_side, 0), max(center_x - half_side, 0)
        crop_y2, crop_x2 = min(center_y + half_side, frame_h), min(center_x + half_side, frame_w)
        box_w, box_h = crop_x2 - crop_x1, crop_y2 - crop_y1
        if box_w <= 0 or box_h <= 0:
            return None

        source_img = face_pil
        if random.random() < flip_chance:
            source_img = source_img.transpose(Image.Transpose.FLIP_LEFT_RIGHT)
        face_resized = source_img.resize((box_w, box_h), Image.Resampling.LANCZOS)

        crop_rgb = np.ascontiguousarray(frame_rgb[crop_y1:crop_y2, crop_x1:crop_x2])
        mask_512 = self._face_parsing_mask(crop_rgb, dilation_kernel, feather_amount)
        mask_resized_to_crop = cv2.resize(mask_512, (box_w, box_h), interpolation=cv2.INTER_LINEAR)

        target_frame_pil = Image.fromarray(frame_rgb)
        target_frame_pil.paste(
            face_resized, (crop_x1, crop_y1), mask=Image.fromarray(mask_resized_to_crop)
        )

        final_mask = np.zeros((frame_h, frame_w), dtype=np.uint8)
        final_mask[crop_y1:crop_y2, crop_x1:crop_x2] = mask_resized_to_crop
        return np.array(target_frame_pil), final_mask

    def _face_parsing_mask(
        self,
        crop_rgb: np.ndarray,
        dilation_kernel: Optional[np.ndarray],
        feather_amount: int,
    ) -> np.ndarray:
        """Run the parsing model on an RGB crop and return a 512x512 uint8 mask."""
        size = self._PARSING_INPUT_SIZE
        face_rgb_512 = cv2.resize(crop_rgb, (size, size), interpolation=cv2.INTER_AREA)
        face_tensor_in = (
            torch.from_numpy(face_rgb_512.astype(np.float32) / 255.0)
            .permute(2, 0, 1)
            .unsqueeze(0)
            .to(self.device)
        )

        with torch.no_grad():
            normalized_face = normalize(face_tensor_in, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            parsing_map = self.parsing_model(normalized_face)[0].argmax(dim=1, keepdim=True)

        parsing_map_np = parsing_map.squeeze().cpu().numpy().astype(np.uint8)
        mask = np.isin(parsing_map_np, self._FACE_PARSING_LABELS).astype(np.uint8) * 255

        if dilation_kernel is not None:
            mask = cv2.dilate(mask, dilation_kernel, iterations=1)
        if feather_amount > 0:
            mask = cv2.GaussianBlur(mask, (feather_amount, feather_amount), 0)
        return mask