# Fun-CineForge-Demo / utils.py
# (Hugging Face upload metadata: xuan3986, "Upload 111 files", commit 03022ee verified)
import os
import shutil
import json
import re
from pydub import AudioSegment
from moviepy.video.io.VideoFileClip import VideoFileClip
from speaker_diarization.local.vision_processer import VisionProcesser
import wave
# ==================== 工具函数 ====================
def get_video_duration(video_path):
    """Return the duration of a video file in seconds.

    Args:
        video_path: Path to a video file readable by moviepy.

    Returns:
        Duration in seconds as a float, or 0.0 if the file cannot be
        opened or probed (best-effort, never raises).
    """
    clip = None
    try:
        clip = VideoFileClip(video_path)
        return clip.duration
    except Exception:
        # Best-effort probe: callers treat 0.0 as "unknown/invalid video".
        return 0.0
    finally:
        # BUG FIX: the original only closed the clip on the success path,
        # leaking the file handle if reading `duration` raised.
        if clip is not None:
            clip.close()
def extract_audio_from_video(video_path: str, wav_path: str, sample_rate: int = 16000):
    """Extract the audio track of a video as a mono WAV file.

    Args:
        video_path: Source video (any format pydub/ffmpeg can read).
        wav_path: Destination path for the exported WAV file.
        sample_rate: Target sample rate in Hz (default 16 kHz).
    """
    print(f"[INFO] Extracting audio from {video_path} to {wav_path}")
    track = AudioSegment.from_file(video_path)
    mono_track = track.set_frame_rate(sample_rate).set_channels(1)
    mono_track.export(wav_path, format="wav")
def extract_visual_embeddings(frontend, vad_list, video_path, wav_path, pkl_path):
    """Run the visual front-end over a clip and dump face embeddings.

    Args:
        frontend: Loaded visual model bundle; its `.conf` is forwarded.
        vad_list: Voice-activity segments as [[start, end], ...] in seconds.
        video_path: Path to the input video clip.
        wav_path: Path to the matching mono WAV file.
        pkl_path: Output path for the pickled visual features.

    Raises:
        Re-raises any processing error after logging it.
    """
    vp = None
    try:
        vp = VisionProcesser(
            video_file_path=video_path,
            audio_file_path=wav_path,
            audio_vad=vad_list,
            out_feat_path=pkl_path,
            visual_models=frontend,
            conf=frontend.conf,
            out_video_path=None,
        )
        vp.run()
    except Exception as e:
        print(f"[ERROR] Failed to process {video_path}: {e}")
        raise
    finally:
        # Close only if construction succeeded (vp stays None otherwise).
        if vp is not None:
            vp.close()
def detect_video_type(video_path):
    """Placeholder video-type detector.

    Always reports the monologue type for now; `video_path` is unused.
    TODO: replace with a real classifier.
    """
    return "独白"
def clip_video_segment(video_path, start_time, end_time, output_dir, clip_name):
    """Cut [start_time, end_time] from a video into <clip_name>.mp4/.wav.

    If the source has no audio track, a silent 16 kHz mono WAV of matching
    length is synthesized so downstream steps always get an audio file.

    Args:
        video_path: Source video file.
        start_time: Segment start in seconds.
        end_time: Segment end in seconds.
        output_dir: Directory receiving the two output files.
        clip_name: Basename (without extension) for the outputs.

    Returns:
        (video_clip_path, audio_clip_path) on success, (None, None) on failure.
    """
    video_clip = os.path.join(output_dir, f"{clip_name}.mp4")
    audio_clip = os.path.join(output_dir, f"{clip_name}.wav")
    clip = None
    try:
        clip = VideoFileClip(video_path).subclipped(start_time, end_time)
        clip.write_videofile(video_clip, codec="libx264", audio_codec='aac', logger=None)
        if clip.audio is not None:
            clip.audio.write_audiofile(audio_clip, codec="pcm_s16le", logger=None)
        else:
            # No audio track: write matching-length 16 kHz mono silence.
            num_samples = int(16000 * (end_time - start_time))
            with wave.open(audio_clip, 'wb') as wf:
                wf.setnchannels(1)
                wf.setsampwidth(2)  # 16-bit PCM
                wf.setframerate(16000)
                wf.writeframes(b'\x00' * (num_samples * 2))
        return video_clip, audio_clip
    except Exception as e:
        # BUG FIX: the original returned a bare None here, which crashed
        # callers that unpack two values; also log instead of swallowing.
        print(f"[ERROR] Failed to clip {video_path}: {e}")
        return None, None
    finally:
        # BUG FIX: the original only closed the clip on success, leaking it
        # when write_videofile/write_audiofile raised.
        if clip is not None:
            clip.close()
def generate_jsonl_data(frontend, video_path, segments_data, work_dir, video_duration):
    """Clip each segment, extract visual features, and write a JSONL manifest.

    For every segment a 0.1 s pad is added on both sides (clamped to the
    video bounds), the clip + audio + visual-embedding pickle are produced
    in `work_dir`, and one JSONL item is appended describing the sample.

    Args:
        frontend: Visual model bundle forwarded to extract_visual_embeddings.
        video_path: Source video file.
        segments_data: Iterable of dicts with 'start', 'end', 'text',
            'gender', 'age', 'clue' and optionally 'ref_audio'.
        work_dir: Output directory for clips, features, and the JSONL file.
        video_duration: Total video length in seconds (for clamping).

    Returns:
        (jsonl_path, jsonl_items): path of the written manifest and the
        list of item dicts.
    """
    video_type = detect_video_type(video_path)
    jsonl_items = []
    for idx, seg in enumerate(segments_data):
        utt_name = f"clip_{idx}"
        # Pad 0.1 s on each side, clamped to [0, video_duration].
        start, end = max(0.0, float(seg['start']) - 0.1), min(float(seg['end']) + 0.1, video_duration)
        duration = end - start
        # BUG FIX: clip_video_segment may signal failure (None / (None, None));
        # guard before unpacking instead of crashing on a bare None.
        clip_result = clip_video_segment(
            video_path, start, end,
            work_dir, utt_name
        )
        if not clip_result:
            continue
        video_clip, audio_clip = clip_result
        if not video_clip or not audio_clip:
            continue
        pkl_path = os.path.join(work_dir, f"{utt_name}.pkl")
        extract_visual_embeddings(
            frontend,
            vad_list=[[0.0, round(duration, 2)]],
            video_path=video_clip,
            wav_path=audio_clip,
            pkl_path=pkl_path
        )
        # Prefer a user-supplied reference audio when present; fall back to
        # the clip's own audio track.
        ref_audio_path = audio_clip
        if seg.get('ref_audio') and os.path.exists(seg['ref_audio']):
            src = seg['ref_audio']
            dst = os.path.join(work_dir, f"{utt_name}_ref.wav")
            shutil.copy(src, dst)
            ref_audio_path = dst
        item = {
            "messages": [
                {"role": "text", "content": seg['text']},
                {"role": "vocal", "content": ref_audio_path},
                {"role": "video", "content": video_clip},
                {"role": "face", "content": pkl_path},
                {"role": "dialogue", "content": [{
                    "start": 0.0,
                    "duration": round(duration, 2),
                    "spk": "1",
                    "gender": seg['gender'],
                    "age": seg['age']
                }]},
                {"role": "clue", "content": seg['clue']}
            ],
            "utt": utt_name,
            "type": video_type,
            # 25 fps frame budget for the speech portion — TODO confirm.
            "speech_length": int(duration * 25),
            "start": start,
            "end": end
        }
        jsonl_items.append(item)
    jsonl_path = os.path.join(work_dir, "input_data.jsonl")
    with open(jsonl_path, 'w', encoding='utf-8') as f:
        for item in jsonl_items:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')
    return jsonl_path, jsonl_items
def validate_timestamps(start, end, video_duration):
    """Validate a dubbing segment's timestamps.

    Rules: start >= 0, end <= video_duration, start < end, and the
    resulting duration must be in (2, 30) seconds.

    Args:
        start: Segment start in seconds.
        end: Segment end in seconds.
        video_duration: Total video length in seconds.

    Returns:
        List of human-readable error strings; empty means valid.
    """
    errors = []
    if start < 0:
        errors.append(f"起始时间 ({start}s) 不能小于 0")
    if end > video_duration:
        errors.append(f"终止时间 ({end}s) 不能大于视频总时长 ({video_duration}s)")
    duration = end - start
    if duration <= 0:
        errors.append(f"起始时间 ({start}s) 必须小于终止时间 ({end}s)")
    elif duration <= 2:
        # BUG FIX: originally this also fired for duration <= 0, reporting
        # the same mistake twice; only flag "too short" for positive spans.
        errors.append(f"配音时长 ({duration}s) 太短,必须大于 2s")
    elif duration >= 30:
        errors.append(f"配音时长 ({duration}s) 太长,请小于 30s")
    return errors
#=== SRT 解析 ====
def parse_srt_time(time_str: str) -> float:
    """Convert an SRT timestamp to seconds.

    Accepts "HH:MM:SS,mmm" (comma, per the SRT spec) or "HH:MM:SS.mmm",
    and — generalized from the original — also tolerates a missing or
    short fractional part ("HH:MM:SS", "HH:MM:SS,5").

    Args:
        time_str: Timestamp string, surrounding whitespace allowed.

    Returns:
        Time in seconds as a float.
    """
    time_str = time_str.strip().replace(',', '.')
    h, m, s = time_str.split(':')
    # partition never raises: ms is '' when there is no fractional part.
    s_part, _, ms = s.partition('.')
    # Right-pad so "5" means .500 s, matching millisecond semantics.
    millis = int(ms.ljust(3, '0')) if ms else 0
    return int(h) * 3600 + int(m) * 60 + int(s_part) + millis / 1000.0
def parse_srt_content(srt_text: str) -> list:
    """Parse SRT subtitle text into a list of segments.

    Recognizes cue-index lines (optionally tagged "N spkX"), reads the
    "HH:MM:SS,mmm --> HH:MM:SS,mmm" line that follows, then collects the
    cue's text lines up to the next index line or end of input.

    Args:
        srt_text: Raw SRT file contents (any line-ending style).

    Returns:
        List of {"start": float, "end": float, "text": str} dicts;
        cues with empty text are dropped.
    """
    if not srt_text:
        return []
    lines = srt_text.replace('\r\n', '\n').replace('\r', '\n').strip().split('\n')
    total = len(lines)
    index_re = re.compile(r'^\d+(\s+spk\d+)?$')
    time_re = re.compile(
        r'(\d{2}:\d{2}:\d{2}[,.]\d{3})\s*-->\s*(\d{2}:\d{2}:\d{2}[,.]\d{3})'
    )
    segments = []
    pos = 0
    while pos < total:
        # Guard clauses: advance one line unless this is a complete cue head.
        if not index_re.match(lines[pos].strip()) or pos + 1 >= total:
            pos += 1
            continue
        time_match = time_re.search(lines[pos + 1])
        if not time_match:
            pos += 1
            continue
        start_sec = parse_srt_time(time_match.group(1))
        end_sec = parse_srt_time(time_match.group(2))
        # Gather text lines until the next cue index or end of input.
        cursor = pos + 2
        collected = []
        while cursor < total and not index_re.match(lines[cursor].strip()):
            stripped = lines[cursor].strip()
            if stripped:
                collected.append(stripped)
            cursor += 1
        joined = ' '.join(collected)
        if joined:
            segments.append({"start": start_sec, "end": end_sec, "text": joined})
        # Jump past the whole processed cue block.
        pos = cursor
    return segments