# Fun-CineForge-Demo / funcineforge / models / inference_model.py
# (uploaded by xuan3986 — "Upload 111 files", commit 03022ee, verified)
import logging
import os
import shutil
import time

import numpy as np
import torch
import torch.nn as nn
import torchaudio
# AudioFileClip canonically lives in moviepy.audio.io; importing it from the
# VideoFileClip module relied on an incidental re-export.
from moviepy.video.io.VideoFileClip import VideoFileClip
from moviepy.audio.io.AudioFileClip import AudioFileClip

from funcineforge.utils.set_all_random_seed import set_all_random_seed
class FunCineForgeInferModel(nn.Module):
    """End-to-end speech-generation inference wrapper.

    Chains three sub-models: a language model (text -> speech codec tokens),
    a feature model (codec -> mel-style feature) and a vocoder
    (feature -> waveform), and optionally saves the feature, the wav, and an
    audio-dubbed copy of the input silent video.
    """

    def __init__(
        self,
        lm_model,
        fm_model,
        voc_model,
        **kwargs
    ):
        """Wrap the three AutoModel-style holders.

        Args:
            lm_model: holder with .model (LM) and .kwargs["tokenizer"].
            fm_model: holder with .model (feature model) and .kwargs["frontend"].
            voc_model: holder with .model (vocoder).
            **kwargs: unused here; accepted for interface compatibility.
        """
        # NOTE(review): AutoModel is never referenced below; the import may
        # exist only for its registration side effects — confirm before removing.
        from funcineforge.auto.auto_model import AutoModel
        super().__init__()
        self.tokenizer = lm_model.kwargs["tokenizer"]
        self.frontend = fm_model.kwargs["frontend"]
        # Unwrap the holders to the underlying nn.Modules.
        self.lm_model = lm_model.model
        self.fm_model = fm_model.model
        self.voc_model = voc_model.model
        mel_extractor = self.fm_model.mel_extractor
        if mel_extractor:
            # Frame rate in mel frames per second; sample rate of the output wav.
            self.mel_frame_rate = mel_extractor.sampling_rate // mel_extractor.hop_length
            self.sample_rate = mel_extractor.sampling_rate
        else:
            # No mel extractor: fall back to the FM model's own sample rate.
            # 480 is presumably the hop length at that rate — TODO confirm.
            self.mel_frame_rate = self.fm_model.sample_rate // 480
            self.sample_rate = self.fm_model.sample_rate

    @torch.no_grad()
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        **kwargs,
    ):
        """Generate a waveform for one utterance: text -> codec -> feat -> wav.

        Args:
            data_in: batch of input dicts; only data_in[0] is used. May carry a
                "video" path used for dubbing when outputs are saved.
            data_lengths: optional lengths, forwarded to every sub-model.
            key: list of utterance ids; key[0] names all saved artifacts.
            **kwargs: forwarded to the sub-models. Recognized here:
                random_seed (int, default 0) and output_dir (str or None).

        Returns:
            ([[wav]], meta): wav is the generated waveform tensor, or None when
            the LM emitted an empty codec sequence; meta["batch_data_time"] is
            the generated audio duration in seconds (1.0 when nothing was made).
        """
        uttid = key[0]
        logging.info(f"generating {uttid}")
        # text -> codec in [1, T]
        kwargs["tokenizer"] = self.tokenizer
        # Re-seed before every stochastic stage so runs are reproducible.
        set_all_random_seed(kwargs.get("random_seed", 0))
        lm_time = time.time()
        codec, hit_eos, states = self.lm_model.inference(data_in, data_lengths, key, **kwargs)
        logging.info(f"[llm time]: {((time.time()-lm_time)*1000):.2f} ms, [hit_eos]: {hit_eos}, [gen len]: {codec.shape[1]}, [speech tokens]: {codec[0].cpu().tolist()}")
        wav, batch_data_time = None, 1.0
        if codec.shape[1] > 0:
            fm_time = time.time()
            data_in[0]["codec"] = codec
            # codec -> feature
            set_all_random_seed(kwargs.get("random_seed", 0))
            feat = self.fm_model.inference(data_in, data_lengths, key, **kwargs)
            # feat -> wav
            set_all_random_seed(kwargs.get("random_seed", 0))
            wav = self.voc_model.inference([feat[0]], data_lengths, key, **kwargs)
            # output save
            output_dir = kwargs.get("output_dir", None)
            if output_dir is not None:
                self._save_outputs(output_dir, uttid, feat, wav, data_in)
            logging.info(f"fm_voc time: {((time.time()-fm_time)*1000):.2f} ms")
            batch_data_time = wav.shape[1] / self.voc_model.sample_rate
        return [[wav]], {"batch_data_time": batch_data_time}

    def _save_outputs(self, output_dir, uttid, feat, wav, data_in):
        """Persist the feature (.npy), the waveform (.wav) and, when the input
        silent video exists on disk, a ground-truth copy plus a dubbed .mp4."""
        feat_out_dir = os.path.join(output_dir, "feat")
        os.makedirs(feat_out_dir, exist_ok=True)
        np.save(os.path.join(feat_out_dir, f"{uttid}.npy"), feat[0].cpu().numpy())
        wav_out_dir = os.path.join(output_dir, "wav")
        os.makedirs(wav_out_dir, exist_ok=True)
        output_wav_path = os.path.join(wav_out_dir, f"{uttid}.wav")
        torchaudio.save(
            output_wav_path, wav.cpu(),
            sample_rate=self.sample_rate, encoding='PCM_S', bits_per_sample=16
        )
        # Tolerate a missing "video" entry the same way a missing file is
        # tolerated below: simply skip the dubbing step.
        silent_video_path = data_in[0].get("video", "")
        if os.path.exists(silent_video_path):
            video_out_dir = os.path.join(output_dir, "mp4")
            video_gt_dir = os.path.join(output_dir, "gt")
            os.makedirs(video_out_dir, exist_ok=True)
            os.makedirs(video_gt_dir, exist_ok=True)
            output_video_path = os.path.join(video_out_dir, f"{uttid}.mp4")
            copy_video_path = os.path.join(video_gt_dir, f"{uttid}.mp4")
            # Keep an untouched ground-truth copy of the source video.
            shutil.copy2(silent_video_path, copy_video_path)
            self.merge_video_audio(
                silent_video_path=silent_video_path,
                wav_path=output_wav_path,
                output_path=output_video_path,
            )

    def merge_video_audio(self, silent_video_path, wav_path, output_path):
        """Mux the generated audio onto the silent video and write output_path.

        The audio is trimmed to the video duration when it is longer; when it
        is shorter, the video keeps its full length with the shorter track.
        Uses the moviepy 2.x API (subclipped / with_audio).
        """
        video_clip = VideoFileClip(silent_video_path)
        audio_clip = AudioFileClip(wav_path)
        try:
            if audio_clip.duration >= video_clip.duration:
                audio_clip = audio_clip.subclipped(0, video_clip.duration)
            video_clip = video_clip.with_audio(audio_clip)
            video_clip.write_videofile(
                output_path,
                codec='libx264',
                audio_codec='aac',
                fps=video_clip.fps,
                logger=None
            )
        finally:
            # Always release the underlying ffmpeg readers, even when the
            # write fails — the original leaked them on error.
            video_clip.close()
            audio_clip.close()