# Source: Hugging Face Space file "app.py" by xuan3986 (commit 77a2d30, verified)
# app.py
import os
import json
import torch
import gradio as gr
import typing
import time
import shutil
# FIX: AudioFileClip is defined in moviepy.audio.io.AudioFileClip. Importing it
# from moviepy.video.io.VideoFileClip only worked by accident, via that module's
# own internal import — use the canonical location instead.
from moviepy.video.io.VideoFileClip import VideoFileClip
from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.audio.AudioClip import CompositeAudioClip
from huggingface_hub import snapshot_download
from utils import get_video_duration, generate_jsonl_data, validate_timestamps, parse_srt_content
# Try to import the model libraries
from funcineforge import AutoFrontend
from speaker_diarization.run import GlobalModels
# Fetch the pretrained FunCineForge weights from the Hugging Face Hub into
# ./pretrained_models. Existing complete files are kept (force_download=False)
# and partial downloads are resumed; docs and one config file are skipped.
snapshot_download(
    repo_id="FunAudioLLM/Fun-CineForge",
    repo_type="model",
    local_dir='pretrained_models',
    token=None,
    resume_download=True,
    force_download=False,
    ignore_patterns=[
        "*.md",
        ".git*",
        "funcineforge_zh_en/llm/config.yaml",
    ],
)
# ==================== Configuration ====================
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
SERVER_PORT = 7860
TEMP_DIR = "temp_workdir"                  # scratch workdir, wiped at the start of every dubbing run
CONFIG_FRONTEND = "decode_conf/diar.yaml"  # speaker-diarization frontend config
CONFIG_MODEL = "decode_conf/decode.yaml"   # dubbing-model decode config
PRETRAIN = "pretrained_models"             # populated by snapshot_download above
MAX_SEGMENTS = 8  # upper bound on dubbing-clip panels in the UI
DEFAULT_VIDEO_PATH = "data/sample.mp4"
DEFAULT_AUDIO_PATH = "data/ref.wav"
# Sample script / clue used by the "Load sample video" button (Chinese demo text).
DEFAULT_TEXT = "我军无粮,利在急战。今乘魏兵新败,不敢出兵,出其不意,乘机退去,方可平安无事。"
DEFAULT_CLUE = "一位中年男性以沉稳但略带担忧的语调,分析我军无粮急战的困境与敌军心败状态。他随即提出一种撤退方案,整体流露出对战局的担忧和谋求生路。"
# Global model instances (lazily loaded)
model_pool: typing.Optional[GlobalModels] = None
engine: typing.Optional[AutoFrontend] = None
def init_engine():
    """Return the global AutoFrontend dubbing engine, creating it on first use.

    The surrounding comment promised lazy loading, but the original body
    rebuilt the engine (re-loading model weights) on every call. Caching the
    instance in the module-level ``engine`` keeps subsequent dubbing requests
    fast while preserving the lazy first-load behavior.

    NOTE(review): TEMP_DIR is wiped and recreated per run; the cached engine
    keeps pointing at the same path, which appears intended — confirm that
    AutoFrontend holds no state tied to the directory's previous contents.
    """
    global engine
    if engine is None:
        engine = AutoFrontend(PRETRAIN, CONFIG_MODEL, TEMP_DIR, DEVICE)
    return engine
def init_frontend_models():
    """Return the global speaker-diarization model pool, creating it on first use.

    The original body rebuilt the GlobalModels pool (and preloaded all models)
    on every call; caching it in the module-level ``model_pool`` avoids
    re-loading the face / active-speaker-detection / face-recognition models
    for each dubbing request.
    """
    global model_pool
    if model_pool is None:
        model_pool = GlobalModels(
            hf_token=None,
            config_path=CONFIG_FRONTEND,
            pretrained_dir=PRETRAIN,
            device=DEVICE,
            pool_sizes={"face": 1, "asd": 1, "fr": 1},
            batch_size=1,
            preload=True,
        )
    return model_pool
# ==================== Gradio UI logic ====================
def create_segments_ui():
    """Build MAX_SEGMENTS accordion panels of dubbing-clip widgets.

    Only the first panel starts open, visible, and enabled; the rest are
    revealed by the "add clip" button. Returns ``(segments, accordions)``
    where each entry of *segments* is a dict of the panel's components keyed
    by role ("text", "clue", "start", "end", "age", "gender", "audio",
    "enable", "index", "load_audio_btn").
    """
    segments = []
    accordions = []
    for idx in range(MAX_SEGMENTS):
        is_first = idx == 0
        with gr.Accordion(f"🎬 Dubbing clip {idx + 1}", open=is_first, visible=is_first) as panel:
            accordions.append(panel)
            with gr.Row():
                script_box = gr.Textbox(label="📝 Dubbing script", placeholder="Please enter the script...", lines=2, scale=3, elem_id=f"text_{idx}")
                clue_box = gr.Textbox(label="💡 Clue description", placeholder="A middle-aged male character speaks with a calm and firm tone, revealing a strong confidence and determination in his own loyalty. The overall emotion conveys an unwavering commitment and an unquestionable belief.", lines=2, scale=3, elem_id=f"clue_{idx}")
            with gr.Row():
                t_start = gr.Number(label="⏱️ Start timestamp (s)", value=0.0 + idx * 5, precision=2, scale=2, elem_id=f"start_{idx}")
                t_end = gr.Number(label="⏱️ End timestamp (s)", value=5.0 + idx * 5, precision=2, scale=2, elem_id=f"end_{idx}")
            with gr.Row():
                age_dd = gr.Dropdown(label="👤 Age", choices=["child", "teenager", "adult", "middle-aged", "elderly", "unknown"], value="unknown", scale=2, elem_id=f"age_{idx}")
                gender_dd = gr.Dropdown(label="👤 Gender", choices=["male", "female", "unknown"], value="unknown", scale=2, elem_id=f"gender_{idx}")
            with gr.Row():
                ref_clip = gr.Audio(label="🎤 Reference audio (optional, the video's audio is used as the reference audio by default).", sources=["upload"], type="filepath", scale=4, elem_id=f"audio_{idx}")
                # Only the first panel carries the sample-audio shortcut button.
                sample_btn = gr.Button("📂 Load sample audio", size="sm", variant="secondary", scale=1) if is_first else None
            with gr.Row():
                enabled_cb = gr.Checkbox(label="Enable this clip", value=is_first, scale=1, elem_id=f"enable_{idx}")
            segments.append({
                "accordion": panel,
                "text": script_box,
                "clue": clue_box,
                "start": t_start,
                "end": t_end,
                "age": age_dd,
                "gender": gender_dd,
                "audio": ref_clip,
                "enable": enabled_cb,
                "index": idx,
                "load_audio_btn": sample_btn,
            })
    return segments, accordions
def add_segment_fn(current_count):
    """Handle the "add clip" button: reveal one more panel, disable at the cap.

    Returns ``[new_count] + accordion visibility updates + button update``,
    matching the outputs wired in main().
    """
    if current_count >= MAX_SEGMENTS:
        # Already at the limit: leave panels untouched, lock the button.
        unchanged = [gr.update() for _ in range(MAX_SEGMENTS)]
        locked = gr.update(interactive=False, value=f"The limit has been reached. ({MAX_SEGMENTS})")
        return [current_count] + unchanged + [locked]
    new_count = current_count + 1
    visibility = [gr.update(visible=(slot < new_count)) for slot in range(MAX_SEGMENTS)]
    button_state = gr.update(interactive=(new_count < MAX_SEGMENTS), value="➕ New clip")
    return [new_count] + visibility + [button_state]
def load_srt_fn(srt_file, current_count):
    """Populate the clip panels from an uploaded .srt subtitle file.

    Returns ``[new_count]`` + four updates per segment (text, start, end,
    enable) + accordion visibility updates + the add-button update — the
    exact output list wired in main().
    """
    empty_fields = [gr.update() for _ in range(MAX_SEGMENTS * 4)]
    empty_vis = [gr.update() for _ in range(MAX_SEGMENTS)]
    if not srt_file:
        return [current_count] + empty_fields + empty_vis + [gr.update()]
    try:
        # utf-8-sig tolerates a BOM, common in SRT files exported on Windows.
        with open(srt_file, 'r', encoding='utf-8-sig') as f:
            content = f.read()
    except Exception as e:
        gr.Warning(f"Failed to read SRT file: {e}")
        return [current_count] + empty_fields + empty_vis + [gr.update()]
    parsed = parse_srt_content(content)
    if not parsed:
        # FIX: surface the problem in the UI via gr.Warning, consistent with the
        # read-failure path above (the original only print()ed to the server
        # console, invisible to the user).
        gr.Warning("No valid subtitles were parsed. Please check the SRT format.")
        return [current_count] + empty_fields + empty_vis + [gr.update()]
    updates = []
    for i in range(MAX_SEGMENTS):
        if i < len(parsed):
            # Fill and enable this clip from the parsed subtitle entry.
            seg = parsed[i]
            updates.append(gr.update(value=seg['text']))
            updates.append(gr.update(value=round(seg['start'], 2)))
            updates.append(gr.update(value=round(seg['end'], 2)))
            updates.append(gr.update(value=True))
        else:
            # Reset trailing clips to their defaults and disable them.
            updates.append(gr.update(value=""))
            updates.append(gr.update(value=0.0))
            updates.append(gr.update(value=5.0 + i*5))
            updates.append(gr.update(value=False))
    new_count = min(len(parsed), MAX_SEGMENTS)
    vis = [gr.update(visible=(i < new_count)) for i in range(MAX_SEGMENTS)]
    btn = gr.update(interactive=(new_count < MAX_SEGMENTS))
    if len(parsed) > MAX_SEGMENTS:
        gr.Warning(f"The SRT contains {len(parsed)} fragments, of which the first {MAX_SEGMENTS} have been truncated.")
    return [new_count] + updates + vis + [btn]
def process_dubbing(video_file, *segment_inputs, progress=gr.Progress()):
    """Main inference pipeline: validate the inputs, run the dubbing model,
    and composite the generated speech back onto the (muted) source video.

    Args:
        video_file: path of the uploaded video.
        *segment_inputs: flattened per-clip widget values, 8 per clip in the
            order (text, clue, start, end, age, gender, ref_audio, enable)
            — must match the wiring in main().
        progress: Gradio progress reporter.

    Returns:
        (output_video_path_or_None, status/report string).
    """
    if not video_file:
        return None, "❌ Please upload the video file."
    video_duration = get_video_duration(video_file)
    if video_duration <= 0:
        return None, "❌ Unable to obtain video duration, please check the video file."
    # Start every run from a clean scratch directory.
    if os.path.exists(TEMP_DIR):
        try:
            shutil.rmtree(TEMP_DIR)
        except Exception as e:
            return None, f"❌ Failed to clear temporary directory:{e}"
    os.makedirs(TEMP_DIR, exist_ok=True)
    # Unpack segment_inputs: 8 consecutive values per clip (see main()).
    segments_data = []
    for i in range(MAX_SEGMENTS):
        base_idx = i * 8
        enable = segment_inputs[base_idx + 7]  # enable_check
        if not enable: continue
        text = segment_inputs[base_idx + 0]
        if not text or not text.strip(): continue  # skip enabled-but-empty clips
        clue = segment_inputs[base_idx + 1]
        start = segment_inputs[base_idx + 2]
        end = segment_inputs[base_idx + 3]
        age = segment_inputs[base_idx + 4]
        gender = segment_inputs[base_idx + 5]
        ref_audio = segment_inputs[base_idx + 6]
        errors = validate_timestamps(start, end, video_duration)
        if errors:
            return None, f"❌ Clip {i+1} timestamp error:\n" + "\n".join(errors)
        # Normalize every field so downstream JSONL generation sees plain
        # str/float values even when a widget returned None.
        data = {
            "text": str(text).strip(),
            "clue": str(clue) if clue else "",
            "start": float(start) if start else 0.0,
            "end": float(end) if end else 0.0,
            "age": str(age) if age else "unknown",
            "gender": str(gender) if gender else "unknown",
            "ref_audio": str(ref_audio) if ref_audio else ""
        }
        segments_data.append(data)
    if not segments_data:
        return None, "❌ The valid clip data is empty. Please enable and fill in at least one clip."
    try:
        progress(0.1, desc="📋 Preprocess the video to generate JSONL data...")
        frontend = init_frontend_models()
        jsonl_path, jsonl_items = generate_jsonl_data(frontend, video_file, segments_data, TEMP_DIR, video_duration)
        # Human-readable report of the generated JSONL items, shown in the UI.
        report_lines = [f"✅ Task completed! A total of **{len(jsonl_items)}** data fragments were generated.\n", "Detailed JSONL data preview:**", "=" * 40]
        for idx, item in enumerate(jsonl_items):
            report_lines.extend([f"\n---Clip #{idx + 1} ---", json.dumps(item, ensure_ascii=False, indent=2), "-" * 40])
        full_report = "\n".join(report_lines)
        progress(0.3, desc="🔄 FunCineForge dubbing model loading...")
        eng = init_engine()
        if eng and jsonl_items:
            try:
                progress(0.5, desc="🚀 FunCineForge dubbing model inference...")
                eng.inference(jsonl_path)
                progress(0.8, desc="🎵 Pasting the voiceover back into the muted video...")
                # The engine is expected to write one .wav per clip under TEMP_DIR/wav.
                output_wav_dir = os.path.join(TEMP_DIR, "wav")
                final_video_path = os.path.join(TEMP_DIR, "dubbed_video.mp4")
                if not os.path.exists(output_wav_dir):
                    return None, f"⚠️ Audio output directory not found:{output_wav_dir}"
                wav_files = sorted([f for f in os.listdir(output_wav_dir) if f.endswith('.wav')])
                if not wav_files:
                    return None, f"⚠️ No audio files were generated:{output_wav_dir}"
                # Map each generated wav back to its clip start time by matching
                # the filename prefix against the JSONL utterance id ('utt').
                time_mapping = {}
                for item in jsonl_items:
                    for wf in wav_files:
                        if wf.startswith(item['utt']):
                            time_mapping[wf] = float(item['start'])
                            break
                original_clip = VideoFileClip(video_file)
                video_duration = original_clip.duration
                is_silent = original_clip.audio is None
                # Strip the source audio track (no-op if the video is silent).
                video_only = original_clip if is_silent else original_clip.without_audio()
                # Place each dubbed wav at its clip's start time on the timeline.
                audio_clips = []
                for wav_file, start_time in time_mapping.items():
                    wav_path = os.path.join(output_wav_dir, wav_file)
                    audio_clip = AudioFileClip(wav_path).with_start(start_time)
                    audio_clips.append(audio_clip)
                final_audio = CompositeAudioClip(audio_clips)
                if final_audio.duration < video_duration:
                    # Pad the composite track so it spans the whole video.
                    final_audio = final_audio.with_duration(video_duration)
                final_clip = video_only.with_audio(final_audio)
                final_clip.write_videofile(
                    final_video_path,
                    codec='libx264',
                    audio_codec='aac',
                    preset='veryfast',
                    threads=8,
                    fps=original_clip.fps,
                    logger=None
                )
                # Release moviepy resources (file readers / ffmpeg subprocesses).
                original_clip.close(); video_only.close()
                for ac in audio_clips: ac.close()
                if 'final_audio' in locals(): final_audio.close()
                final_clip.close()
                progress(1.0, desc="✅ Dubbing complete")
                return final_video_path, full_report
            except Exception as e:
                import traceback; traceback.print_exc()
                if "index out of range" in str(e):
                    # Heuristic: this model error usually means missing clue/attributes.
                    return None, f"⚠️ Model inference failed. Error: {str(e)}. It is recommended to complete the input clue description and speaker attributes."
                else:
                    return None, f"⚠️ Model inference failed. Error: {str(e)}"
        else:
            # No engine / no data: pass the original video through unchanged.
            time.sleep(1)
            progress(1.0, desc="Simulation complete")
            return video_file, full_report
    except Exception as e:
        import traceback; traceback.print_exc()
        return None, f"❌ Error: {str(e)}"
# ==================== Main program ====================
def main():
    """Build the Gradio UI, wire its events, and launch the web server."""
    os.makedirs(TEMP_DIR, exist_ok=True)
    with gr.Blocks(
        title="Fun-CineForge-Demo",
        theme=gr.themes.Soft(),
        css="""
        .segment-accordion { margin: 10px 0; }
        .gr-button-primary { background: #1976d2; }
        .gr-button-stop { background: #d32f2f; }
        """
    ) as demo:
        gr.Markdown("""
        # 🎬 Fun-CineForge
        **Workflow:** Upload short video → Add clip information (or upload .srt subtitle file) → Upload reference audio (optional) → Preprocessing, model loading, and inference → Output dubbed video
        """)
        with gr.Row():
            with gr.Column(scale=1):
                video_input = gr.Video(label="Upload video", sources=["upload"])
                load_video_btn = gr.Button("📂 Load sample video", variant="secondary", size="sm")
                srt_input = gr.UploadButton("Upload SRT subtitles", file_types=[".srt"], size="sm", variant="secondary")
                gr.Markdown("### 🎛️ Dubbing clip configuration")
                segments, accordions = create_segments_ui()
                seg_count_state = gr.State(1)  # 🔑 number of currently visible clip panels
                add_segment_btn = gr.Button("➕Add new clip", size="sm", variant="secondary")
                submit_btn = gr.Button("🚀 Start dubbing", variant="stop", size="lg")
            with gr.Column(scale=1):
                video_output = gr.Video(label="📺 Dubbed video", autoplay=True)
                status_text = gr.Textbox(label="Result status", interactive=False, lines=2)
                gr.Markdown("""
                ### 📝 Instructions for use
                | Fields | Descriptions |
                |------|------|
                | Dubbing script | The content of this clip (supports Chinese/English) |
                | Clue description | Please refer to the sample format to explain the dubbing requirements, focusing on describing the speaker's gender, age, tone, and emotion |
                | Timestamps | Start and end timestamps (accurate to milliseconds). The model is sensitive to timestamps; it is recommended to use timestamps adjacent to the audio clip. Duration ≤ 30s/clip |
                | Age/Gender | Speaker attribute options |
                | Reference audio | Voice cloning reference (Optional) |
                **⚠️ Note:** Ensure that the timestamps of each clip don't overlap and don't exceed the video duration. The model will perform time alignment based on the timestamps, with weak supervision aligning lip movements.
                """)
        # ==================== Event wiring ====================
        # Flatten every clip's components, 8 per clip, in the exact order
        # process_dubbing() unpacks them.
        segment_inputs = []
        for seg in segments:
            segment_inputs.extend([
                seg["text"],
                seg["clue"],
                seg["start"],
                seg["end"],
                seg["age"],
                seg["gender"],
                seg["audio"],
                seg["enable"]
            ])
        # Fields the SRT loader writes: (text, start, end, enable) per clip.
        srt_update_fields = []
        for seg in segments:
            srt_update_fields.extend([seg["text"], seg["start"], seg["end"], seg["enable"]])
        # Dynamically reveal additional clip panels
        add_segment_btn.click(
            fn=add_segment_fn,
            inputs=[seg_count_state],
            outputs=[seg_count_state] + accordions + [add_segment_btn]
        )
        # SRT loading
        srt_input.upload(
            fn=load_srt_fn,
            inputs=[srt_input, seg_count_state],
            outputs=[seg_count_state] + srt_update_fields + accordions + [add_segment_btn]
        )
        # Main inference
        submit_btn.click(
            fn=process_dubbing,
            inputs=[video_input] + segment_inputs,
            outputs=[video_output, status_text]
        )
        # Reset every clip's start/end to [0, duration] when a video is set
        def update_timestamps(video):
            if not video: return [gr.update() for _ in range(MAX_SEGMENTS * 2)]
            dur = get_video_duration(video)
            updates = []
            for i in range(MAX_SEGMENTS):
                updates.append(gr.update(value=0.0))
                updates.append(gr.update(value=dur))
            return updates
        def load_default_video_fn():
            # Fill the sample video plus the first clip's script and clue.
            return DEFAULT_VIDEO_PATH, DEFAULT_TEXT, DEFAULT_CLUE
        def load_default_audio_fn():
            return DEFAULT_AUDIO_PATH
        load_video_btn.click(
            fn=load_default_video_fn,
            inputs=[],
            outputs=[video_input, segments[0]["text"], segments[0]["clue"]]
        ).then(
            fn=update_timestamps,
            inputs=[video_input],
            # indices 2 and 3 of every 8-wide group are the start/end Numbers
            outputs=[segment_inputs[i] for i in range(len(segment_inputs)) if i % 8 in [2, 3]]
        )
        video_input.change(
            fn=update_timestamps,
            inputs=[video_input],
            # same start/end components, selected via strided slices
            outputs=[comp for pair in zip(segment_inputs[2::8], segment_inputs[3::8]) for comp in pair]
        )
        # Only clip 1 has a sample-audio button (see create_segments_ui).
        if segments and segments[0]["load_audio_btn"]:
            segments[0]["load_audio_btn"].click(
                fn=load_default_audio_fn,
                inputs=[],
                outputs=[segments[0]["audio"]]
            )
        # ==================== Launch the server ====================
        demo.launch(
            server_name="0.0.0.0",
            server_port=SERVER_PORT,
            share=False,
            show_error=True,
            inbrowser=True,
        )
# Script entry point
if __name__ == "__main__":
    main()