| | import os |
| |
|
| | |
| | os.environ["CUDA_MODULE_LOADING"] = "LAZY" |
| | |
| | os.environ["SAFETENSORS_FAST_GPU"] = "1" |
| | import cv2 |
| | import torch |
| | import time |
| | import imageio |
| | import numpy as np |
| | from tqdm import tqdm |
| | import moviepy.editor as mp |
| | import torch |
| |
|
| | from audio import load_wav, melspectrogram |
| | from fete_model import FETE_model |
| | from preprocess_videos import face_detect, load_from_npz |
| |
|
# Frame rate (frames per second) of the video written by the inference code.
fps = 25
# Number of mel-spectrogram columns advanced per video frame; the chunking
# code multiplies the frame index by this value to find a window start.
mel_idx_multiplier = 80.0 / fps

# Width, in mel-spectrogram columns, of the window cut out for each frame.
mel_step_size = 16

# Query CUDA once and derive every hardware-dependent setting from it.
_cuda_available = torch.cuda.is_available()
batch_size = 64 if _cuda_available else 4
device = "cuda" if _cuda_available else "cpu"
print("Using {} for inference.".format(device))

# Half precision is only enabled when running on the GPU.
use_fp16 = _cuda_available
if use_fp16:
    print("Using FP16 for inference.")

torch.backends.cudnn.benchmark = device == "cuda"
| |
|
| |
|
def init_model():
    """Create the FETE model, load its checkpoint and put it in eval mode.

    The checkpoint is read from ``checkpoints/obama-fp16.safetensors`` next to
    this file. ``.pth``/``.ckpt`` files are loaded with ``torch.load`` and are
    expected to hold the weights under the ``"state_dict"`` key; anything else
    is read tensor-by-tensor with ``safetensors``. Every ``"module."``
    substring is removed from the parameter names before loading.

    When ``use_fp16`` is set, modules are cast to half precision except those
    whose name contains ``.query_conv``, ``.key_conv`` or ``.value_conv``,
    which are cast to float32.

    Returns:
        The loaded model, moved to ``device``.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    checkpoint_path = os.path.join(base_dir, "checkpoints/obama-fp16.safetensors")
    model = FETE_model()

    if checkpoint_path.endswith((".pth", ".ckpt")):
        if device == "cuda":
            checkpoint = torch.load(checkpoint_path)
        else:
            # Keep every storage on the CPU when no GPU is in use.
            checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        raw_state = checkpoint["state_dict"]
    else:
        from safetensors import safe_open

        with safe_open(checkpoint_path, framework="pt", device=device) as handle:
            raw_state = {key: handle.get_tensor(key) for key in handle.keys()}

    cleaned_state = {name.replace("module.", ""): tensor for name, tensor in raw_state.items()}
    model.load_state_dict(cleaned_state)

    model = model.to(device)
    model.eval()
    print("Model loaded")

    if use_fp16:
        float_markers = (".query_conv", ".key_conv", ".value_conv")
        # Iteration order is kept as-is: a cast applied to a submodule later in
        # the walk overrides the cast it received through its parent.
        for name, module in model.named_modules():
            stays_float = any(marker in name for marker in float_markers)
            module.to(torch.float if stays_float else torch.half)
        print("Model converted to half precision to accelerate inference")
    return model
| |
|
| |
|
def make_mask(image_size=256, border_size=32):
    """Build a 3-channel float32 blending mask with soft left/right/bottom borders.

    The mask is 0 in the interior and ramps linearly up to 1 towards the left
    and right edges over ``border_size`` pixels. The bottom ``border_size`` rows
    get a vertical ramp as well and are then floored at 0.6. The top edge has
    no ramp of its own.

    Args:
        image_size: Side length of the square mask.
        border_size: Width of each ramp in pixels.

    Returns:
        Array of shape ``(image_size, image_size, 3)`` and dtype float32.
    """
    ramp = np.linspace(1, 0, border_size)
    # Every row holds the same 1 -> 0 ramp: shape (image_size, border_size).
    side_ramp = np.tile(ramp, (image_size, 1))

    canvas = np.zeros((image_size, image_size), dtype=np.float32)
    # Bottom band: 0 at its top row rising to 1 at the image's last row.
    canvas[-border_size:, :] += side_ramp.T[::-1]
    # Left band: 1 at the edge falling to 0 towards the interior.
    canvas[:, :border_size] = side_ramp
    # Right band: mirror image of the left band.
    canvas[:, -border_size:] = side_ramp[:, ::-1]

    # Floor the whole bottom band at 0.6 (view into ``canvas``, edited in place).
    bottom_band = canvas[-border_size:, :]
    bottom_band[bottom_band < 0.6] = 0.6

    return np.repeat(canvas[:, :, np.newaxis], 3, axis=-1).astype(np.float32)


# Default mask shared by every blend; resized per face crop when used.
face_mask = make_mask()
| |
|
| |
|
def blend_images(foreground, background):
    """Mix two images using the module-level ``face_mask`` as per-pixel weight.

    ``face_mask`` is resized to the foreground's height and width; the result
    is ``foreground * mask + background * (1 - mask)``, clipped to ``[0, 255]``.

    Args:
        foreground: Image weighted by the mask.
        background: Image weighted by the inverse mask; must have the same
            size as ``foreground``.

    Returns:
        The blended image as a uint8 array.
    """
    target_size = (foreground.shape[1], foreground.shape[0])  # cv2 wants (width, height)
    weights = cv2.resize(face_mask, target_size)

    mixed = cv2.multiply(foreground.astype(np.float32), weights)
    mixed += cv2.multiply(background.astype(np.float32), 1 - weights)
    return np.clip(mixed, 0, 255).astype(np.uint8)
| |
|
| |
|
def smooth_coord(last_coord, current_coord, factor=0.4):
    """Move ``last_coord`` a fraction of the way towards ``current_coord``.

    Args:
        last_coord: Previous coordinates (sequence of numbers).
        current_coord: New coordinates, same length as ``last_coord``.
        factor: Fraction of the difference to apply.

    Returns:
        A list of ints: ``last + (current - last) * factor``, with the
        fractional part dropped by the integer cast.
    """
    previous = np.array(last_coord)
    step = (np.array(current_coord) - previous) * factor
    smoothed = previous + step
    return smoothed.astype(int).tolist()
| |
|
| |
|
def add_black(imgs):
    """Pad every image with a black band: 100 rows on top, 20 at the bottom.

    The list is modified in place (each entry is replaced by its padded
    version) and also returned.

    Args:
        imgs: List of 3-channel uint8 images.

    Returns:
        The same list object, now holding the padded images.
    """
    for position, image in enumerate(imgs):
        width = image.shape[1]
        top_band = np.zeros((100, width, 3), dtype=np.uint8)
        bottom_band = np.zeros((20, width, 3), dtype=np.uint8)
        imgs[position] = cv2.vconcat([top_band, image, bottom_band])
    return imgs
| |
|
| |
|
def remove_black(img):
    """Drop the first 100 and the last 20 rows of ``img``.

    This is the inverse of the padding applied by ``add_black``.
    """
    top_rows, bottom_rows = 100, 20
    return img[top_rows:-bottom_rows]
| |
|
| |
|
def resize_length(input_attributes, length):
    """Resample a sequence along its first axis to ``length`` entries.

    Output entry ``i`` is the input entry at index
    ``int(i * (len(input) / length))``, so no interpolation takes place.

    Args:
        input_attributes: Array-like indexed by time along axis 0.
        length: Number of entries wanted.

    Returns:
        The resampled array, transposed (time ends up on the last axis).
    """
    source = np.array(input_attributes)
    picked = []
    for target_idx in range(length):
        source_idx = int(target_idx * (source.shape[0] / length))
        picked.append(source[source_idx])
    return np.array(picked).T
| |
|
| |
|
def output_chunks(input_attributes):
    """Cut a 2-D array into overlapping windows along its second axis.

    Window ``i`` starts at column ``int(i * mel_idx_multiplier)`` and is
    ``mel_step_size`` columns wide. As soon as a window would run past the
    last column, one final window covering the last ``mel_step_size`` columns
    is appended and the search stops.

    Args:
        input_attributes: 2-D array; columns are sliced as ``[:, a:b]``.

    Returns:
        List of the window arrays, in order.
    """
    total_columns = len(input_attributes[0])
    windows = []
    frame_idx = 0
    while True:
        start = int(frame_idx * mel_idx_multiplier)
        stop = start + mel_step_size
        if stop > total_columns:
            # Tail window: anchored to the end instead of to ``start``.
            windows.append(input_attributes[:, total_columns - mel_step_size :])
            return windows
        windows.append(input_attributes[:, start:stop])
        frame_idx += 1
| |
|
| |
|
def prepare_data(face_path, audio_path, pose, emotion, blink, img_size=256, pads=None):
    """Load the face source and the audio and build the batch generator.

    Args:
        face_path: Path to a still image (``jpg``/``jpeg``/``png``, any case)
            or to a video readable by OpenCV.
        audio_path: Path to the driving audio, loaded at 16 kHz.
        pose: Pose attribute sequence, indexable along its first axis.
        emotion: Emotion attribute sequence, indexable along its first axis.
        blink: Blink attribute sequence, indexable along its first axis.
        img_size: Side length the face crops are resized to.
        pads: Padding forwarded to face detection; defaults to ``[0, 0, 0, 0]``.

    Returns:
        ``(gen, steps)``: the ``datagen`` generator and the number of batches
        it will produce.

    Raises:
        ValueError: If the image cannot be read, the video yields no frame, or
            the mel-spectrogram contains NaN.
    """
    if pads is None:
        pads = [0, 0, 0, 0]

    # Use the real file extension: splitting on the first "." breaks on paths
    # such as "./assets/face.jpg" or "clip.v2.png".
    extension = os.path.splitext(face_path)[1].lower().lstrip(".")
    static = os.path.isfile(face_path) and extension in ("jpg", "png", "jpeg")

    if static:
        image = cv2.imread(face_path)
        if image is None:
            raise ValueError("Could not read image file: {}".format(face_path))
        full_frames = [image]
    else:
        video_stream = cv2.VideoCapture(face_path)
        full_frames = []
        try:
            while True:
                still_reading, frame = video_stream.read()
                if not still_reading:
                    break
                full_frames.append(frame)
        finally:
            # Release the capture even if reading fails part-way.
            video_stream.release()
        if not full_frames:
            # Without this check datagen would loop forever on an empty list.
            raise ValueError("Could not read any frame from: {}".format(face_path))
    print("Number of frames available for inference: " + str(len(full_frames)))

    wav = load_wav(audio_path, 16000)
    mel = melspectrogram(wav)
    len_ = mel.shape[1]

    # Stretch every attribute track to one entry per mel column so that all
    # four streams can be windowed identically.
    pose = resize_length(pose, len_)
    emotion = resize_length(emotion, len_)
    blink = resize_length(blink, len_)

    if np.isnan(mel).any():
        raise ValueError("Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again")

    mel_chunks = output_chunks(mel)
    pose_chunks = output_chunks(pose)
    emotion_chunks = output_chunks(emotion)
    blink_chunks = output_chunks(blink)

    gen = datagen(face_path, full_frames, mel_chunks, pose_chunks, emotion_chunks, blink_chunks, static=static, img_size=img_size, pads=pads)
    steps = int(np.ceil(float(len(mel_chunks)) / batch_size))

    return gen, steps
| |
|
| |
|
def preprocess_batch(batch):
    """Stack a list of 2-D arrays into a ``(N, 1, H, W)`` float tensor on ``device``.

    ``H`` and ``W`` are taken from the first item of ``batch``.
    """
    first = batch[0]
    target_shape = [len(batch), 1, first.shape[0], first.shape[1]]
    stacked = np.reshape(batch, target_shape)
    return torch.FloatTensor(stacked).to(device)
| |
|
| |
|
def _crop_faces(frames, boxes):
    """Pair each frame's face crop with its box reordered as (y1, y2, x1, x2)."""
    return [[image[y1:y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(frames, boxes)]


def _pack_batch(img_batch, mel_batch, pose_batch, emotion_batch, blink_batch, scale_factor):
    """Convert the per-frame lists of one batch into the tuple of model inputs."""
    faces = np.asarray(img_batch)
    masked = faces.copy()

    # Black out the centre of the copy, leaving a 16 * scale_factor pixel rim.
    margin = 16 * scale_factor
    masked[:, margin:-margin, margin:-margin] = 0.0

    # Masked and original faces are joined on the last axis and then moved
    # channel-first (presumably 6 channels for 3-channel crops).
    stacked = np.concatenate((masked, faces), axis=3) / 255.0
    tensors = (
        torch.FloatTensor(np.transpose(stacked, (0, 3, 1, 2))).to(device),
        preprocess_batch(mel_batch),
        preprocess_batch(pose_batch),
        preprocess_batch(emotion_batch),
        preprocess_batch(blink_batch),
    )
    if use_fp16:
        tensors = tuple(tensor.half() for tensor in tensors)
    return tensors


def datagen(face_path, frames, mels, poses, emotions, blinks, static=False, img_size=256, pads=None):
    """Yield model-ready batches for every audio window.

    Frames are padded with ``add_black``. Face boxes are loaded with
    ``load_from_npz`` using the file name stem of ``face_path``; if that fails
    for any reason, ``face_detect`` is run instead. When there are fewer frames
    than audio windows the frame list is extended by appending its reverse
    until it is long enough.

    Args:
        face_path: Path of the face source; only its base name is used here.
        frames: List of frames (images).
        mels: List of mel windows, one per output frame.
        poses: List of pose windows, same length as ``mels``.
        emotions: List of emotion windows, same length as ``mels``.
        blinks: List of blink windows, same length as ``mels``.
        static: If true, frame 0 is used for every output frame.
        img_size: Side length the face crops are resized to.
        pads: Padding forwarded to ``face_detect``; defaults to ``[0, 0, 0, 0]``.

    Yields:
        ``(inputs, frame_batch, coords_batch)`` where ``inputs`` is the tuple
        ``(img, mel, pose, emotion, blink)`` of tensors on ``device`` (half
        precision when ``use_fp16``), ``frame_batch`` holds copies of the
        padded frames and ``coords_batch`` the ``(y1, y2, x1, x2)`` boxes.

    Raises:
        ValueError: If there are audio windows but no frames.
    """
    if pads is None:
        pads = [0, 0, 0, 0]
    scale_factor = img_size // 128

    # The slice creates a new list, so the caller's list is not modified by
    # the in-place padding.
    frames = add_black(frames[: len(mels)])
    if len(mels) > 0 and not frames:
        # An empty list can never be grown by the ping-pong loop below.
        raise ValueError("No frames available to pair with the audio windows")

    try:
        video_name = os.path.basename(face_path).split(".")[0]
        boxes = load_from_npz(video_name)
        face_det_results = _crop_faces(frames, boxes)
    except Exception as e:
        # Deliberately broad: any failure to use cached boxes falls back to
        # running detection.
        print("No existing coords found, running face detection...", "Error: ", e)
        boxes = face_detect([frames[0]] if static else frames, pads)
        face_det_results = _crop_faces(frames, boxes)

    # Ping-pong: forward + backward repeats until the audio length is covered.
    while len(frames) < len(mels):
        face_det_results = face_det_results + face_det_results[::-1]
        frames = frames + frames[::-1]
    face_det_results = face_det_results[: len(mels)]
    frames = frames[: len(mels)]

    img_batch, mel_batch, pose_batch, emotion_batch, blink_batch = [], [], [], [], []
    frame_batch, coords_batch = [], []

    for i in range(len(mels)):
        idx = 0 if static else i % len(frames)
        frame_to_save = frames[idx].copy()
        face, box = face_det_results[idx]
        face = cv2.resize(face, (img_size, img_size))

        img_batch.append(face)
        mel_batch.append(mels[i])
        pose_batch.append(poses[i])
        emotion_batch.append(emotions[i])
        blink_batch.append(blinks[i])
        frame_batch.append(frame_to_save)
        coords_batch.append(box)

        if len(img_batch) >= batch_size:
            inputs = _pack_batch(img_batch, mel_batch, pose_batch, emotion_batch, blink_batch, scale_factor)
            yield inputs, frame_batch, coords_batch
            img_batch, mel_batch, pose_batch, emotion_batch, blink_batch = [], [], [], [], []
            frame_batch, coords_batch = [], []

    # Flush the last, possibly smaller, batch.
    if img_batch:
        inputs = _pack_batch(img_batch, mel_batch, pose_batch, emotion_batch, blink_batch, scale_factor)
        yield inputs, frame_batch, coords_batch
| |
|
| |
|
def _composite_frames(model, gen, total):
    """Run the model on every batch and yield finished BGR frames one by one."""
    for inputs, frames, coords in tqdm(gen, total=total):
        with torch.no_grad():
            pred = model(*inputs)

        # Scale in float32: half-precision outputs would otherwise be
        # multiplied by 255 in fp16.
        pred = pred.float().cpu().numpy().transpose(0, 2, 3, 1) * 255.0

        for p, f, c in zip(pred, frames, coords):
            y1, y2, x1, x2 = (int(v) for v in c)
            height, width = y2 - y1, x2 - x1
            # Clip before the uint8 cast so out-of-range values saturate
            # instead of wrapping around.
            p = cv2.resize(np.clip(p, 0, 255).astype(np.uint8), (width, height))

            region = (slice(y1, y1 + height), slice(x1, x1 + width))
            try:
                f[region] = blend_images(f[region], p)
            except Exception as e:
                # Best effort: if blending fails, paste the raw prediction.
                print(e)
                f[region] = p
            yield remove_black(f)


def _mux_audio(video_path, audio_path, outfile):
    """Write ``outfile`` from the silent video plus the audio track, closing both clips."""
    video_clip = mp.VideoFileClip(video_path)
    try:
        audio_clip = mp.AudioFileClip(audio_path)
        try:
            video_clip.set_audio(audio_clip).write_videofile(outfile, codec="libx264")
        finally:
            audio_clip.close()
    finally:
        video_clip.close()


def infenrece(model, face_path, audio_path, pose, emotion, blink, preview=False):
    """Generate a talking-face video (or a one-frame preview) and return its path.

    Args:
        model: Callable network taking the five tensors produced by ``datagen``.
        face_path: Path to the source image or video.
        audio_path: Path to the driving audio.
        pose: Pose attribute sequence passed to ``prepare_data``.
        emotion: Emotion attribute sequence passed to ``prepare_data``.
        blink: Blink attribute sequence passed to ``prepare_data``.
        preview: If true, only the first generated frame is written, as a JPEG.

    Returns:
        ``/tmp/<timestamp>.jpg`` in preview mode, otherwise
        ``/tmp/<timestamp>.mp4`` (video with the audio track attached).

    Raises:
        RuntimeError: In preview mode, if no frame was generated.
    """
    timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime(time.time()))
    gen, steps = prepare_data(face_path, audio_path, pose, emotion, blink)

    if preview:
        outfile = "/tmp/{}.jpg".format(timestamp)
        for frame in _composite_frames(model, gen, total=1):
            cv2.imwrite(outfile, frame, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
            return outfile
        raise RuntimeError("No frame was generated for the preview")

    outfile = "/tmp/{}.mp4".format(timestamp)
    tmp_video = "/tmp/temp_{}.mp4".format(timestamp)
    try:
        writer = imageio.get_writer(tmp_video, fps=fps, codec="libx264", quality=10, pixelformat="yuv420p", macro_block_size=1)
        try:
            for frame in _composite_frames(model, gen, total=steps):
                writer.append_data(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            # Always finalise the file, even if a batch fails.
            writer.close()
        _mux_audio(tmp_video, audio_path, outfile)
    finally:
        # The silent intermediate video is never needed afterwards.
        try:
            os.remove(tmp_video)
        except FileNotFoundError:
            pass  # nothing was written
        except OSError as e:
            print("Could not remove temporary file {}: {}".format(tmp_video, e))

    print("Saved to {}".format(outfile) if os.path.exists(outfile) else "Failed to save {}".format(outfile))
    return outfile
| |
|
| |
|
if __name__ == "__main__":
    # Demo run: load the model, build the attribute tracks and render the
    # bundled sample clip.
    model = init_model()

    from attributtes_utils import input_pose, input_emotion, input_blink

    pose = input_pose()
    emotion = input_emotion()
    blink = input_blink()

    audio_path = "./assets/sample.wav"
    face_path = "./assets/sample.mp4"

    infenrece(
        model,
        face_path,
        audio_path,
        pose,
        emotion,
        blink,
    )
| |
|