# File size: 4,707 Bytes
# 03022ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import torch
import torch.nn as nn
import logging
import numpy as np
import os
import torchaudio
import time
import shutil
from funcineforge.utils.set_all_random_seed import set_all_random_seed
from moviepy.video.io.VideoFileClip import VideoFileClip, AudioFileClip


class FunCineForgeInferModel(nn.Module):
    """Inference wrapper that chains three sub-models: LM -> FM -> vocoder.

    The language model turns the input into codec tokens, the ``fm`` model
    turns those tokens into acoustic features, and the vocoder turns the
    features into a waveform. When an ``output_dir`` is supplied, the
    features, the waveform and (if a source video exists) a muxed video are
    written to disk.
    """

    def __init__(
        self,
        lm_model,
        fm_model,
        voc_model,
        **kwargs
    ):
        """Unpack the wrapped models and derive the audio rates.

        Args:
            lm_model: Wrapper exposing ``.model`` and ``.kwargs["tokenizer"]``.
            fm_model: Wrapper exposing ``.model`` and ``.kwargs["frontend"]``.
            voc_model: Wrapper exposing ``.model``.
            **kwargs: Accepted for interface compatibility; unused here.
        """
        # NOTE(review): the imported name is unused in this method; the import
        # is kept in case importing the module has registration side effects
        # -- confirm before removing.
        from funcineforge.auto.auto_model import AutoModel
        super().__init__()
        self.tokenizer = lm_model.kwargs["tokenizer"]
        self.frontend = fm_model.kwargs["frontend"]
        self.lm_model = lm_model.model
        self.fm_model = fm_model.model
        self.voc_model = voc_model.model
        mel_extractor = self.fm_model.mel_extractor
        if mel_extractor:
            self.mel_frame_rate = mel_extractor.sampling_rate // mel_extractor.hop_length
            self.sample_rate = mel_extractor.sampling_rate
        else:
            # No mel extractor: fall back to the fm model's own sample rate
            # with a fixed 480-sample hop (presumably the model's default hop
            # length -- TODO confirm).
            self.mel_frame_rate = self.fm_model.sample_rate // 480
            self.sample_rate = self.fm_model.sample_rate

    @torch.no_grad()
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        **kwargs,
    ):
        """Generate a waveform for one utterance and optionally save outputs.

        Args:
            data_in: Sequence whose first element is a dict-like sample. It is
                mutated: the generated tokens are stored under ``"codec"``.
                Its ``"video"`` entry, if present, is the source video path.
            data_lengths: Forwarded unchanged to the sub-models.
            key: Non-empty list; ``key[0]`` is used as the utterance id and as
                the stem of every output file name.
            **kwargs: Forwarded to the sub-models. ``random_seed`` (default 0)
                reseeds the RNG before each stage; ``output_dir`` (default
                ``None``) enables saving of results.

        Returns:
            A tuple ``([[wav]], {"batch_data_time": seconds})``. ``wav`` is
            ``None`` and ``batch_data_time`` is ``1.0`` when the language
            model produced no tokens.
        """
        uttid = key[0]
        logging.info(f"generating {uttid}")
        # text -> codec in [1, T]
        kwargs["tokenizer"] = self.tokenizer
        seed = kwargs.get("random_seed", 0)
        set_all_random_seed(seed)
        lm_time = time.time()
        codec, hit_eos, states = self.lm_model.inference(data_in, data_lengths, key, **kwargs)
        logging.info(f"[llm time]: {((time.time()-lm_time)*1000):.2f} ms, [hit_eos]: {hit_eos}, [gen len]: {codec.shape[1]}, [speech tokens]: {codec[0].cpu().tolist()}")
        wav, batch_data_time = None, 1.0
        if codec.shape[1] > 0:
            fm_time = time.time()
            # codec -> feat; the fm model reads the tokens from the sample.
            data_in[0]["codec"] = codec
            # Reseed before every stage so each one is reproducible on its own.
            set_all_random_seed(seed)
            feat = self.fm_model.inference(data_in, data_lengths, key, **kwargs)
            # feat -> wav
            set_all_random_seed(seed)
            wav = self.voc_model.inference([feat[0]], data_lengths, key, **kwargs)
            # output save
            output_dir = kwargs.get("output_dir", None)
            if output_dir is not None:
                self._save_outputs(data_in, uttid, feat, wav, output_dir)

            logging.info(f"fm_voc time: {((time.time()-fm_time)*1000):.2f} ms")

            # NOTE(review): duration uses the vocoder's sample rate while the
            # wav file is written with self.sample_rate -- confirm they match.
            batch_data_time = wav.shape[1] / self.voc_model.sample_rate

        return [[wav]], {"batch_data_time": batch_data_time}

    def _save_outputs(self, data_in, uttid, feat, wav, output_dir):
        """Write features, waveform and (when possible) the muxed video to disk."""
        feat_out_dir = os.path.join(output_dir, "feat")
        os.makedirs(feat_out_dir, exist_ok=True)
        np.save(os.path.join(feat_out_dir, f"{uttid}.npy"), feat[0].cpu().numpy())

        wav_out_dir = os.path.join(output_dir, "wav")
        os.makedirs(wav_out_dir, exist_ok=True)
        output_wav_path = os.path.join(wav_out_dir, f"{uttid}.wav")
        torchaudio.save(
            output_wav_path, wav.cpu(),
            sample_rate=self.sample_rate, encoding='PCM_S', bits_per_sample=16
        )

        # A sample without a video entry (or with an empty/None one) must not
        # abort the call after the audio has already been generated and saved.
        try:
            silent_video_path = data_in[0]["video"]
        except KeyError:
            silent_video_path = None
        if not silent_video_path or not os.path.exists(silent_video_path):
            return

        video_out_dir = os.path.join(output_dir, "mp4")
        video_gt_dir = os.path.join(output_dir, "gt")
        os.makedirs(video_out_dir, exist_ok=True)
        os.makedirs(video_gt_dir, exist_ok=True)
        output_video_path = os.path.join(video_out_dir, f"{uttid}.mp4")
        copy_video_path = os.path.join(video_gt_dir, f"{uttid}.mp4")
        # Keep an untouched copy of the source video next to the result.
        shutil.copy2(silent_video_path, copy_video_path)
        self.merge_video_audio(
            silent_video_path=silent_video_path,
            wav_path=output_wav_path,
            output_path=output_video_path,
        )

    def merge_video_audio(self, silent_video_path, wav_path, output_path):
        """Mux the audio at ``wav_path`` onto a video and write the result.

        The audio is trimmed to the video's duration when it is at least as
        long as the video; shorter audio is attached as-is.

        Args:
            silent_video_path: Path of the source video.
            wav_path: Path of the audio file to attach.
            output_path: Destination path of the muxed video (H.264 + AAC).
        """
        source_clip = None
        audio_clip = None
        muxed_clip = None
        try:
            source_clip = VideoFileClip(silent_video_path)
            audio_clip = AudioFileClip(wav_path)
            video_duration = source_clip.duration

            if audio_clip.duration >= video_duration:
                audio_clip = audio_clip.subclipped(0, video_duration)

            muxed_clip = source_clip.with_audio(audio_clip)
            muxed_clip.write_videofile(
                output_path,
                codec='libx264',
                audio_codec='aac',
                fps=muxed_clip.fps,
                logger=None
            )
        finally:
            # Close everything that was opened, even when muxing fails, so the
            # underlying readers are not left open. The source clip is closed
            # too because with_audio() returns a separate clip object.
            for clip in (muxed_clip, source_clip, audio_clip):
                if clip is not None:
                    clip.close()