| import os |
| import argparse |
| import subprocess |
| import torch |
| import numpy as np |
| from tqdm import tqdm |
| from omegaconf import OmegaConf |
| from typing import Tuple, List, Union |
| import decord |
| import json |
| import cv2 |
| from musetalk.utils.face_detection import FaceAlignment,LandmarksType |
| from mmpose.apis import inference_topdown, init_model |
| from mmpose.structures import merge_data_samples |
| import sys |
|
|
| def fast_check_ffmpeg(): |
| try: |
| subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True) |
| return True |
| except: |
| return False |
|
|
| ffmpeg_path = "./ffmpeg-4.4-amd64-static/" |
| if not fast_check_ffmpeg(): |
| print("Adding ffmpeg to PATH") |
| |
| path_separator = ';' if sys.platform == 'win32' else ':' |
| os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}" |
| if not fast_check_ffmpeg(): |
| print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed") |
|
|
| class AnalyzeFace: |
| def __init__(self, device: Union[str, torch.device], config_file: str, checkpoint_file: str): |
| """ |
| Initialize the AnalyzeFace class with the given device, config file, and checkpoint file. |
| |
| Parameters: |
| device (Union[str, torch.device]): The device to run the models on ('cuda' or 'cpu'). |
| config_file (str): Path to the mmpose model configuration file. |
| checkpoint_file (str): Path to the mmpose model checkpoint file. |
| """ |
| self.device = device |
| self.dwpose = init_model(config_file, checkpoint_file, device=self.device) |
| self.facedet = FaceAlignment(LandmarksType._2D, flip_input=False, device=self.device) |
|
|
| def __call__(self, im: np.ndarray) -> Tuple[List[np.ndarray], np.ndarray]: |
| """ |
| Detect faces and keypoints in the given image. |
| |
| Parameters: |
| im (np.ndarray): The input image. |
| maxface (bool): Whether to detect the maximum face. Default is True. |
| |
| Returns: |
| Tuple[List[np.ndarray], np.ndarray]: A tuple containing the bounding boxes and keypoints. |
| """ |
| try: |
| |
| if im.ndim == 3: |
| im = np.expand_dims(im, axis=0) |
| elif im.ndim != 4 or im.shape[0] != 1: |
| raise ValueError("Input image must have shape (1, H, W, C)") |
| |
| bbox = self.facedet.get_detections_for_batch(np.asarray(im)) |
| results = inference_topdown(self.dwpose, np.asarray(im)[0]) |
| results = merge_data_samples(results) |
| keypoints = results.pred_instances.keypoints |
| face_land_mark= keypoints[0][23:91] |
| face_land_mark = face_land_mark.astype(np.int32) |
|
|
| return face_land_mark, bbox |
| |
| except Exception as e: |
| print(f"Error during face analysis: {e}") |
| return np.array([]),[] |
|
|
| def convert_video(org_path: str, dst_path: str, vid_list: List[str]) -> None: |
|
|
| """ |
| Convert video files to a specified format and save them to the destination path. |
| |
| Parameters: |
| org_path (str): The directory containing the original video files. |
| dst_path (str): The directory where the converted video files will be saved. |
| vid_list (List[str]): A list of video file names to process. |
| |
| Returns: |
| None |
| """ |
| for idx, vid in enumerate(vid_list): |
| if vid.endswith('.mp4'): |
| org_vid_path = os.path.join(org_path, vid) |
| dst_vid_path = os.path.join(dst_path, vid) |
| |
| if org_vid_path != dst_vid_path: |
| cmd = [ |
| "ffmpeg", "-hide_banner", "-y", "-i", org_vid_path, |
| "-r", "25", "-crf", "15", "-c:v", "libx264", |
| "-pix_fmt", "yuv420p", dst_vid_path |
| ] |
| subprocess.run(cmd, check=True) |
|
|
| if idx % 1000 == 0: |
| print(f"### {idx} videos converted ###") |
|
|
| def segment_video(org_path: str, dst_path: str, vid_list: List[str], segment_duration: int = 30) -> None: |
| """ |
| Segment video files into smaller clips of specified duration. |
| |
| Parameters: |
| org_path (str): The directory containing the original video files. |
| dst_path (str): The directory where the segmented video files will be saved. |
| vid_list (List[str]): A list of video file names to process. |
| segment_duration (int): The duration of each segment in seconds. Default is 30 seconds. |
| |
| Returns: |
| None |
| """ |
| for idx, vid in enumerate(vid_list): |
| if vid.endswith('.mp4'): |
| input_file = os.path.join(org_path, vid) |
| original_filename = os.path.basename(input_file) |
|
|
| command = [ |
| 'ffmpeg', '-i', input_file, '-c', 'copy', '-map', '0', |
| '-segment_time', str(segment_duration), '-f', 'segment', |
| '-reset_timestamps', '1', |
| os.path.join(dst_path, f'clip%03d_{original_filename}') |
| ] |
|
|
| subprocess.run(command, check=True) |
|
|
| def extract_audio(org_path: str, dst_path: str, vid_list: List[str]) -> None: |
| """ |
| Extract audio from video files and save as WAV format. |
| |
| Parameters: |
| org_path (str): The directory containing the original video files. |
| dst_path (str): The directory where the extracted audio files will be saved. |
| vid_list (List[str]): A list of video file names to process. |
| |
| Returns: |
| None |
| """ |
| for idx, vid in enumerate(vid_list): |
| if vid.endswith('.mp4'): |
| video_path = os.path.join(org_path, vid) |
| audio_output_path = os.path.join(dst_path, os.path.splitext(vid)[0] + ".wav") |
| try: |
| command = [ |
| 'ffmpeg', '-hide_banner', '-y', '-i', video_path, |
| '-vn', '-acodec', 'pcm_s16le', '-f', 'wav', |
| '-ar', '16000', '-ac', '1', audio_output_path, |
| ] |
| |
| subprocess.run(command, check=True) |
| print(f"Audio saved to: {audio_output_path}") |
| except subprocess.CalledProcessError as e: |
| print(f"Error extracting audio from {vid}: {e}") |
|
|
| def split_data(video_files: List[str], val_list_hdtf: List[str]) -> (List[str], List[str]): |
| """ |
| Split video files into training and validation sets based on val_list_hdtf. |
| |
| Parameters: |
| video_files (List[str]): A list of video file names. |
| val_list_hdtf (List[str]): A list of validation file identifiers. |
| |
| Returns: |
| (List[str], List[str]): A tuple containing the training and validation file lists. |
| """ |
| val_files = [f for f in video_files if any(val_id in f for val_id in val_list_hdtf)] |
| train_files = [f for f in video_files if f not in val_files] |
| return train_files, val_files |
|
|
| def save_list_to_file(file_path: str, data_list: List[str]) -> None: |
| """ |
| Save a list of strings to a file, each string on a new line. |
| |
| Parameters: |
| file_path (str): The path to the file where the list will be saved. |
| data_list (List[str]): The list of strings to save. |
| |
| Returns: |
| None |
| """ |
| with open(file_path, 'w') as file: |
| for item in data_list: |
| file.write(f"{item}\n") |
|
|
| def generate_train_list(cfg): |
| train_file_path = cfg.video_clip_file_list_train |
| val_file_path = cfg.video_clip_file_list_val |
| val_list_hdtf = cfg.val_list_hdtf |
|
|
| meta_list = os.listdir(cfg.meta_root) |
|
|
| sorted_meta_list = sorted(meta_list) |
| train_files, val_files = split_data(meta_list, val_list_hdtf) |
|
|
| save_list_to_file(train_file_path, train_files) |
| save_list_to_file(val_file_path, val_files) |
|
|
| print(val_list_hdtf) |
|
|
| def analyze_video(org_path: str, dst_path: str, vid_list: List[str]) -> None: |
| """ |
| Convert video files to a specified format and save them to the destination path. |
| |
| Parameters: |
| org_path (str): The directory containing the original video files. |
| dst_path (str): The directory where the meta json will be saved. |
| vid_list (List[str]): A list of video file names to process. |
| |
| Returns: |
| None |
| """ |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
| config_file = './musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py' |
| checkpoint_file = './models/dwpose/dw-ll_ucoco_384.pth' |
|
|
| analyze_face = AnalyzeFace(device, config_file, checkpoint_file) |
| |
| for vid in tqdm(vid_list, desc="Processing videos"): |
| |
| |
| if vid.endswith('.mp4'): |
| vid_path = os.path.join(org_path, vid) |
| wav_path = vid_path.replace(".mp4",".wav") |
| vid_meta = os.path.join(dst_path, os.path.splitext(vid)[0] + ".json") |
| if os.path.exists(vid_meta): |
| continue |
| print('process video {}'.format(vid)) |
|
|
| total_bbox_list = [] |
| total_pts_list = [] |
| isvalid = True |
|
|
| |
| try: |
| cap = decord.VideoReader(vid_path, fault_tol=1) |
| except Exception as e: |
| print(e) |
| continue |
|
|
| total_frames = len(cap) |
| for frame_idx in range(total_frames): |
| frame = cap[frame_idx] |
| if frame_idx==0: |
| video_height,video_width,_ = frame.shape |
| frame_bgr = cv2.cvtColor(frame.asnumpy(), cv2.COLOR_BGR2RGB) |
| pts_list, bbox_list = analyze_face(frame_bgr) |
|
|
| if len(bbox_list)>0 and None not in bbox_list: |
| bbox = bbox_list[0] |
| else: |
| isvalid = False |
| bbox = [] |
| print(f"set isvalid to False as broken img in {frame_idx} of {vid}") |
| break |
|
|
| |
| if len(pts_list)>0 and pts_list is not None: |
| pts = pts_list.tolist() |
| else: |
| isvalid = False |
| pts = [] |
| break |
|
|
| if frame_idx==0: |
| x1,y1,x2,y2 = bbox |
| face_height, face_width = y2-y1,x2-x1 |
|
|
| total_pts_list.append(pts) |
| total_bbox_list.append(bbox) |
|
|
| meta_data = { |
| "mp4_path": vid_path, |
| "wav_path": wav_path, |
| "video_size": [video_height, video_width], |
| "face_size": [face_height, face_width], |
| "frames": total_frames, |
| "face_list": total_bbox_list, |
| "landmark_list": total_pts_list, |
| "isvalid":isvalid, |
| } |
| with open(vid_meta, 'w') as f: |
| json.dump(meta_data, f, indent=4) |
| |
|
|
|
|
| def main(cfg): |
| |
| os.makedirs(cfg.video_root_25fps, exist_ok=True) |
| os.makedirs(cfg.video_audio_clip_root, exist_ok=True) |
| os.makedirs(cfg.meta_root, exist_ok=True) |
| os.makedirs(os.path.dirname(cfg.video_file_list), exist_ok=True) |
| os.makedirs(os.path.dirname(cfg.video_clip_file_list_train), exist_ok=True) |
| os.makedirs(os.path.dirname(cfg.video_clip_file_list_val), exist_ok=True) |
|
|
| vid_list = os.listdir(cfg.video_root_raw) |
| sorted_vid_list = sorted(vid_list) |
| |
| |
| with open(cfg.video_file_list, 'w') as file: |
| for vid in sorted_vid_list: |
| file.write(vid + '\n') |
|
|
| |
| convert_video(cfg.video_root_raw, cfg.video_root_25fps, sorted_vid_list) |
| |
| |
| segment_video(cfg.video_root_25fps, cfg.video_audio_clip_root, vid_list, segment_duration=cfg.clip_len_second) |
| |
| |
| clip_vid_list = os.listdir(cfg.video_audio_clip_root) |
| extract_audio(cfg.video_audio_clip_root, cfg.video_audio_clip_root, clip_vid_list) |
| |
| |
| analyze_video(cfg.video_audio_clip_root, cfg.meta_root, clip_vid_list) |
| |
| |
| generate_train_list(cfg) |
| print("done") |
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--config", type=str, default="./configs/training/preprocess.yaml") |
| args = parser.parse_args() |
| config = OmegaConf.load(args.config) |
|
|
| main(config) |
| |
|
|