File size: 7,175 Bytes
03022ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import os
import shutil
import json
import re
from pydub import AudioSegment
from moviepy.video.io.VideoFileClip import VideoFileClip
from speaker_diarization.local.vision_processer import VisionProcesser
import wave


# ==================== 工具函数 ====================

def get_video_duration(video_path):
    """Return the duration of a video in seconds (best-effort).

    Args:
        video_path: Path to the video file.

    Returns:
        Duration in seconds, or 0.0 when the file cannot be opened or read.
        Callers treat 0.0 as "unknown duration", so failures are logged
        rather than raised (the original silently discarded the error).
    """
    try:
        clip = VideoFileClip(video_path)
        try:
            return clip.duration
        finally:
            # Ensure the underlying reader is released even if reading
            # `.duration` itself raises (original leaked it in that case).
            clip.close()
    except Exception as e:
        print(f"[WARN] Could not read duration of {video_path}: {e}")
        return 0.0
    
def extract_audio_from_video(video_path: str, wav_path: str, sample_rate: int = 16000):
    """Decode the audio track of *video_path* and save it as a mono WAV.

    The track is resampled to *sample_rate* Hz (default 16 kHz) and mixed
    down to a single channel before being exported to *wav_path*.
    """
    print(f"[INFO] Extracting audio from {video_path} to {wav_path}")
    track = AudioSegment.from_file(video_path)
    mono = track.set_frame_rate(sample_rate).set_channels(1)
    mono.export(wav_path, format="wav")


def extract_visual_embeddings(frontend, vad_list, video_path, wav_path, pkl_path):
    """Run the visual front-end over one clip and dump features to *pkl_path*.

    Args:
        frontend: Visual model bundle; its ``.conf`` is forwarded to the
            processor (exact type defined by the project's frontend module).
        vad_list: Voice-activity segments as ``[[start_s, end_s], ...]``.
        video_path: Path to the source video clip.
        wav_path: Path to the matching mono WAV.
        pkl_path: Output path for the pickled features.

    Raises:
        Re-raises whatever VisionProcesser raises, after logging it.
    """
    # Sentinel instead of the fragile `'vp' in locals()` check: cleanup
    # works even if the constructor itself fails.
    vp = None
    try:
        vp = VisionProcesser(
            video_file_path = video_path,
            audio_file_path = wav_path,
            audio_vad = vad_list,
            out_feat_path = pkl_path,
            visual_models = frontend,
            conf = frontend.conf,
            out_video_path=None
        )
        vp.run()
    except Exception as e:
        print(f"[ERROR] Failed to process {video_path}: {e}")
        raise
    finally:
        if vp is not None:
            vp.close()


def detect_video_type(video_path):
    """[Placeholder] Classify the video's content type.

    Currently ignores *video_path* and always reports "独白" (monologue);
    a real classifier is expected to replace this stub.
    """
    return "独白"

def clip_video_segment(video_path, start_time, end_time, output_dir, clip_name):
    """Cut [start_time, end_time] (seconds) out of *video_path*.

    Writes ``<clip_name>.mp4`` (H.264/AAC) and ``<clip_name>.wav`` into
    *output_dir*.  If the source has no audio track, a silent 16 kHz mono
    16-bit WAV of matching length is synthesized so downstream stages
    always receive a WAV file.

    Returns:
        ``(video_path, wav_path)`` on success, ``(None, None)`` on failure.
        BUGFIX: the original returned a bare ``None`` on failure, but the
        caller unpacks two values, which would raise TypeError instead of
        letting the caller skip the failed segment.
    """
    clip = None
    try:
        video_clip = os.path.join(output_dir, f"{clip_name}.mp4")
        audio_clip = os.path.join(output_dir, f"{clip_name}.wav")
        clip = VideoFileClip(video_path).subclipped(start_time, end_time)
        clip.write_videofile(video_clip, codec="libx264", audio_codec='aac', logger=None)
        if clip.audio is not None:
            clip.audio.write_audiofile(audio_clip, codec="pcm_s16le", logger=None)
        else:
            # No audio track: emit silence of the same duration.
            num_samples = int(16000 * (end_time - start_time))
            with wave.open(audio_clip, 'wb') as wf:
                wf.setnchannels(1)
                wf.setsampwidth(2)  # 16-bit
                wf.setframerate(16000)
                wf.writeframes(b'\x00' * (num_samples * 2))
        return video_clip, audio_clip
    except Exception as e:
        print(f"[ERROR] Failed to clip {video_path} [{start_time}-{end_time}]: {e}")
        return None, None
    finally:
        # Original leaked the reader on failure; close unconditionally.
        if clip is not None:
            clip.close()

def generate_jsonl_data(frontend, video_path, segments_data, work_dir, video_duration):
    """Generate JSONL training data for each segment of a video.

    For every entry in *segments_data* (dicts with at least 'start', 'end',
    'text', 'gender', 'age', 'clue', optionally 'ref_audio'), this:
    1. clips the padded segment out of the video (mp4 + wav),
    2. extracts visual embeddings into a .pkl next to the clips,
    3. builds one JSONL item per segment,
    and finally writes all items to ``<work_dir>/input_data.jsonl``.

    Args:
        frontend: Visual model bundle forwarded to embedding extraction.
        segments_data: List of per-segment dicts (schema above — verify
            against the caller; only the keys used here are guaranteed).
        work_dir: Directory receiving all clips, pkl files, and the JSONL.
        video_duration: Total source duration in seconds (upper clamp).

    Returns:
        Tuple of (jsonl_path, list of item dicts).
    """
    video_type = detect_video_type(video_path)
    
    jsonl_items = []
    
    for idx, seg in enumerate(segments_data):
        utt_name = f"clip_{idx}"
        # Pad 0.1 s on each side, clamped to [0, video_duration].
        start, end = max(0.0, float(seg['start']) - 0.1), min(float(seg['end']) + 0.1, video_duration)
        duration = end - start
        
        
        video_clip, audio_clip = clip_video_segment(
            video_path, start, end, 
            work_dir, utt_name
        )
        # Skip segments whose clipping failed.
        if not video_clip or not audio_clip:
            continue
        
        pkl_path = os.path.join(work_dir, f"{utt_name}.pkl")
        
        # Whole clip treated as one voiced region for the visual front-end.
        extract_visual_embeddings(
            frontend,
            vad_list = [[0.0, round(duration, 2)]],
            video_path = video_clip, 
            wav_path = audio_clip, 
            pkl_path = pkl_path
        )
        
        # Prefer a caller-supplied reference audio; fall back to the clip's
        # own audio track.
        ref_audio_path = audio_clip
        if seg.get('ref_audio') and os.path.exists(seg['ref_audio']):
            src = seg['ref_audio']
            dst = os.path.join(work_dir, f"{utt_name}_ref.wav")
            shutil.copy(src, dst)
            ref_audio_path = dst
        
        item = {
            "messages": [
                {"role": "text", "content": seg['text']},
                {"role": "vocal", "content": ref_audio_path},
                {"role": "video", "content": video_clip},
                {"role": "face", "content": pkl_path},
                # Single-speaker dialogue spanning the whole clip.
                {"role": "dialogue", "content": [{
                    "start": 0.0,
                    "duration": round(duration, 2),
                    "spk": "1",
                    "gender": seg['gender'],
                    "age": seg['age']
                }]},
                {"role": "clue", "content": seg['clue']}
            ],
            "utt": utt_name,
            "type": video_type,
            # presumably frames at 25 fps — TODO confirm against the model.
            "speech_length": int(duration * 25),
            "start": start,
            "end": end
        }
        jsonl_items.append(item)
    
    jsonl_path = os.path.join(work_dir, "input_data.jsonl")
    with open(jsonl_path, 'w', encoding='utf-8') as f:
        for item in jsonl_items:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')
    
    return jsonl_path, jsonl_items

def validate_timestamps(start, end, video_duration):
    """Validate a dubbing window against the source video's duration.

    Checks bounds (start >= 0, end <= video_duration) and length
    (strictly between 2 s and 30 s).  A zero-length window reports both
    the ordering error and the too-short error, matching the original.

    Returns:
        List of human-readable error strings; empty means valid.
    """
    problems = []
    if start < 0:
        problems.append(f"起始时间 ({start}s) 不能小于 0")
    if end > video_duration:
        problems.append(f"终止时间 ({end}s) 不能大于视频总时长 ({video_duration}s)")
    span = end - start
    if span <= 0:
        problems.append(f"起始时间 ({start}s) 必须小于终止时间 ({end}s)")
    if 0 <= span <= 2:
        problems.append(f"配音时长 ({span}s) 太短,必须大于 2s")
    if span >= 30:
        problems.append(f"配音时长 ({span}s) 太长,请小于 30s")
    return problems

#=== SRT 解析 ====
def parse_srt_time(time_str: str) -> float:
    """Convert an SRT timestamp (``HH:MM:SS,mmm`` or ``HH:MM:SS.mmm``) to seconds."""
    normalized = time_str.strip().replace(',', '.')
    hours, minutes, rest = normalized.split(':')
    whole_seconds, millis = rest.split('.')
    return 3600 * int(hours) + 60 * int(minutes) + int(whole_seconds) + int(millis) / 1000.0

def parse_srt_content(srt_text: str) -> list:
    """Parse SRT subtitle text into segment dicts.

    Recognizes index lines (optionally suffixed with ``spkN``), a
    ``HH:MM:SS,mmm --> HH:MM:SS,mmm`` timestamp line, and then collects
    text lines until the next index line.  Blocks with empty text are
    dropped.

    Returns:
        List of ``{"start": float, "end": float, "text": str}`` dicts.
    """
    if not srt_text:
        return []
    normalized = srt_text.replace('\r\n', '\n').replace('\r', '\n').strip()
    lines = normalized.split('\n')
    index_re = re.compile(r'^\d+(\s+spk\d+)?$')
    time_re = re.compile(r'(\d{2}:\d{2}:\d{2}[,.]\d{3})\s*-->\s*(\d{2}:\d{2}:\d{2}[,.]\d{3})')
    segments = []
    total = len(lines)
    pos = 0
    while pos < total:
        # Anything that is not an index line followed by a timestamp line
        # is skipped one line at a time.
        if not index_re.match(lines[pos].strip()) or pos + 1 >= total:
            pos += 1
            continue
        stamp = time_re.search(lines[pos + 1])
        if not stamp:
            pos += 1
            continue
        start_sec = parse_srt_time(stamp.group(1))
        end_sec = parse_srt_time(stamp.group(2))
        # Gather text up to the next index line (or end of input).
        cursor = pos + 2
        collected = []
        while cursor < total and not index_re.match(lines[cursor].strip()):
            stripped = lines[cursor].strip()
            if stripped:
                collected.append(stripped)
            cursor += 1
        body = ' '.join(collected)
        if body:
            segments.append({"start": start_sec, "end": end_sec, "text": body})
        pos = cursor  # resume scanning at the next block
    return segments