import logging
import os
import shutil
import time

import numpy as np
import torch
import torch.nn as nn
import torchaudio
from moviepy.video.io.VideoFileClip import VideoFileClip, AudioFileClip

from funcineforge.utils.set_all_random_seed import set_all_random_seed


class FunCineForgeInferModel(nn.Module):
    """Chain a language model, a feature model and a vocoder for inference.

    ``inference`` runs the three wrapped models in sequence
    (``lm_model`` -> ``fm_model`` -> ``voc_model``) and, when an
    ``output_dir`` keyword is supplied, writes the intermediate feature,
    the waveform and a dubbed copy of the input video to disk.
    """

    def __init__(
        self,
        lm_model,
        fm_model,
        voc_model,
        **kwargs
    ):
        """Unwrap the three model containers.

        Args:
            lm_model: Wrapper exposing ``.model`` and ``.kwargs["tokenizer"]``.
            fm_model: Wrapper exposing ``.model`` and ``.kwargs["frontend"]``.
            voc_model: Wrapper exposing ``.model``.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        # NOTE(review): this name is unused in the method; the import is kept
        # in case importing the module has side effects -- confirm and drop.
        from funcineforge.auto.auto_model import AutoModel  # noqa: F401

        super().__init__()
        self.tokenizer = lm_model.kwargs["tokenizer"]
        self.frontend = fm_model.kwargs["frontend"]
        self.lm_model = lm_model.model
        self.fm_model = fm_model.model
        self.voc_model = voc_model.model

        mel_extractor = self.fm_model.mel_extractor
        if mel_extractor:
            self.mel_frame_rate = mel_extractor.sampling_rate // mel_extractor.hop_length
            self.sample_rate = mel_extractor.sampling_rate
        else:
            # No mel extractor: fall back to the feature model's sample rate
            # with a fixed divisor of 480 (presumably the hop length in
            # samples -- TODO confirm).
            self.mel_frame_rate = self.fm_model.sample_rate // 480
            self.sample_rate = self.fm_model.sample_rate

    @torch.no_grad()
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        **kwargs,
    ):
        """Generate a waveform for one utterance and optionally save results.

        Args:
            data_in: Indexable batch; ``data_in[0]`` receives the generated
                ``"codec"`` entry and, when saving, must provide ``"video"``.
            data_lengths: Forwarded unchanged to the three sub-models.
            key: Utterance ids; ``key[0]`` names the output files.
            **kwargs: Forwarded to the sub-models. ``random_seed``
                (default 0) reseeds before every stage; ``output_dir``
                (default ``None``) enables writing outputs to disk.

        Returns:
            ``([[wav]], {"batch_data_time": seconds})``. ``wav`` is ``None``
            and ``batch_data_time`` is ``1.0`` when the language model
            produced no tokens.

        Raises:
            ValueError: If ``key`` is ``None`` or empty.
        """
        if key is None or len(key) == 0:
            raise ValueError("`key` must contain at least one utterance id")

        uttid = key[0]
        logging.info(f"generating {uttid}")

        # text -> codec in [1, T]
        kwargs["tokenizer"] = self.tokenizer
        seed = kwargs.get("random_seed", 0)
        set_all_random_seed(seed)
        lm_time = time.time()
        codec, hit_eos, states = self.lm_model.inference(data_in, data_lengths, key, **kwargs)
        logging.info(f"[llm time]: {((time.time()-lm_time)*1000):.2f} ms, [hit_eos]: {hit_eos}, [gen len]: {codec.shape[1]}, [speech tokens]: {codec[0].cpu().tolist()}")

        wav, batch_data_time = None, 1.0
        if codec.shape[1] > 0:
            fm_time = time.time()
            # codec -> feat
            data_in[0]["codec"] = codec
            set_all_random_seed(seed)
            feat = self.fm_model.inference(data_in, data_lengths, key, **kwargs)
            # feat -> wav
            set_all_random_seed(seed)
            wav = self.voc_model.inference([feat[0]], data_lengths, key, **kwargs)

            # output save
            output_dir = kwargs.get("output_dir", None)
            if output_dir is not None:
                self._save_outputs(output_dir, uttid, feat, wav, data_in[0])

            # The reported time includes the optional saving step above.
            logging.info(f"fm_voc time: {((time.time()-fm_time)*1000):.2f} ms")
            # NOTE(review): duration uses voc_model.sample_rate while the WAV
            # file is written at self.sample_rate -- confirm they agree.
            batch_data_time = wav.shape[1] / self.voc_model.sample_rate

        return [[wav]], {"batch_data_time": batch_data_time}

    def _save_outputs(self, output_dir, uttid, feat, wav, sample):
        """Write the feature, waveform and dubbed video for one utterance.

        Creates ``feat/<uttid>.npy`` and ``wav/<uttid>.wav`` under
        ``output_dir``. When ``sample["video"]`` exists on disk it is copied
        to ``gt/<uttid>.mp4`` and a version carrying the generated audio is
        written to ``mp4/<uttid>.mp4``.
        """
        feat_out_dir = os.path.join(output_dir, "feat")
        os.makedirs(feat_out_dir, exist_ok=True)
        np.save(os.path.join(feat_out_dir, f"{uttid}.npy"), feat[0].cpu().numpy())

        wav_out_dir = os.path.join(output_dir, "wav")
        os.makedirs(wav_out_dir, exist_ok=True)
        output_wav_path = os.path.join(wav_out_dir, f"{uttid}.wav")
        torchaudio.save(
            output_wav_path,
            wav.cpu(),
            sample_rate=self.sample_rate,
            encoding='PCM_S',
            bits_per_sample=16,
        )

        silent_video_path = sample["video"]
        if not os.path.exists(silent_video_path):
            return

        video_out_dir = os.path.join(output_dir, "mp4")
        video_gt_dir = os.path.join(output_dir, "gt")
        os.makedirs(video_out_dir, exist_ok=True)
        os.makedirs(video_gt_dir, exist_ok=True)
        output_video_path = os.path.join(video_out_dir, f"{uttid}.mp4")
        copy_video_path = os.path.join(video_gt_dir, f"{uttid}.mp4")
        shutil.copy2(silent_video_path, copy_video_path)
        self.merge_video_audio(
            silent_video_path=silent_video_path,
            wav_path=output_wav_path,
            output_path=output_video_path,
        )

    def merge_video_audio(self, silent_video_path, wav_path, output_path):
        """Mux ``wav_path`` onto ``silent_video_path`` and write ``output_path``.

        Audio that is at least as long as the video is cut to the video
        duration; shorter audio is attached as is. Both source clips are
        closed even when loading, trimming or encoding fails.

        Args:
            silent_video_path: Path of the input video.
            wav_path: Path of the audio file to attach.
            output_path: Destination path of the encoded video.
        """
        source_video = VideoFileClip(silent_video_path)
        source_audio = None
        try:
            source_audio = AudioFileClip(wav_path)
            audio_track = source_audio
            if source_audio.duration >= source_video.duration:
                audio_track = source_audio.subclipped(0, source_video.duration)

            dubbed = source_video.with_audio(audio_track)
            dubbed.write_videofile(
                output_path,
                codec='libx264',
                audio_codec='aac',
                fps=dubbed.fps,
                logger=None,
            )
        finally:
            # Derived clips are copies that reuse the source clips' readers,
            # so closing the sources releases the underlying ffmpeg handles.
            source_video.close()
            if source_audio is not None:
                source_audio.close()