Spaces:
Sleeping
Sleeping
| """Gradio用户界面模块 | |
| 提供基于Gradio的Web界面,支持文件上传、进度显示和结果展示。 | |
| """ | |
| import asyncio | |
| import json | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple, Any | |
| import gradio as gr | |
| import pandas as pd | |
| from ..core.config import get_config | |
| from ..core.task_manager import get_task_manager, TaskStatus, TaskPriority | |
| from ..utils.logger import get_task_logger | |
| from ..services.file_validator import get_file_validator | |
class GradioInterface:
    """Gradio web UI for the audio transcription service.

    Builds a ``gr.Blocks`` page with file upload, transcription parameters,
    status/result panels and a details section, and wires the page events to
    the task manager returned by ``get_task_manager()``.
    """

    # Interval (seconds) of the automatic status refresh; only used when the
    # installed Gradio version provides ``gr.Timer``.
    _POLL_INTERVAL_SECONDS = 2.0

    def __init__(self):
        """Resolve service dependencies, build the UI and subscribe to task status changes."""
        self.config = get_config()
        self.task_manager = get_task_manager()
        self.file_validator = get_file_validator()
        self.logger = get_task_logger(logger_name="transcript_service.gradio")
        # ID of the task most recently created from this UI (None when idle).
        # NOTE(review): this is instance state, so every browser session that
        # shares this object also shares the "current task"; per-session
        # tracking would require gr.State.
        self.current_task_id = None
        # Build the Blocks layout and wire its events.
        self.interface = self._create_interface()
        # The callback only logs; the UI itself refreshes through
        # _update_interface (page load, refresh button, timer when available).
        self.task_manager.add_status_callback(self._on_task_status_change)

    def _create_interface(self) -> gr.Blocks:
        """Build the Gradio Blocks layout and register its event handlers.

        Returns:
            The constructed (not yet launched) ``gr.Blocks`` object.
        """
        # Format info from the validator; 'extensions' feeds the upload filter.
        supported_formats = self.file_validator.get_supported_formats()

        with gr.Blocks(
            title="音频转文字服务",
            theme=gr.themes.Soft(),
            css="""
            .main-container { max-width: 1000px; margin: 0 auto; }
            .upload-area { border: 2px dashed #ccc; border-radius: 10px; padding: 20px; text-align: center; }
            .result-area { margin-top: 20px; }
            .status-simple { font-size: 16px; font-weight: bold; }
            """
        ) as interface:
            gr.Markdown("# 🎵 音频转文字服务")

            with gr.Row():
                with gr.Column(scale=3):
                    # Upload area (multiple files allowed).
                    file_upload = gr.File(
                        label="📁 选择音频文件(支持多文件)",
                        file_count="multiple",
                        file_types=list(supported_formats['extensions']),
                        height=120
                    )

                    with gr.Row():
                        # Task priority; the value is mapped to TaskPriority in _process_files.
                        priority_select = gr.Radio(
                            label="优先级",
                            choices=[("普通", "NORMAL"), ("高优先级", "HIGH")],
                            value="NORMAL"
                        )

                    # Transcription parameters, collapsed by default.
                    with gr.Accordion("⚙️ 转录参数设置", open=False):
                        language_select = gr.CheckboxGroup(
                            label="识别语言",
                            choices=[
                                ("中文", "zh"), ("英文", "en"), ("日语", "ja"),
                                ("粤语", "yue"), ("韩语", "ko"), ("德语", "de"),
                                ("法语", "fr"), ("俄语", "ru")
                            ],
                            value=["zh", "en"]
                        )

                        with gr.Row():
                            disfluency_removal = gr.Checkbox(
                                label="过滤语气词",
                                value=True
                            )
                            timestamp_alignment = gr.Checkbox(
                                label="时间戳校准",
                                value=True
                            )
                            diarization_enabled = gr.Checkbox(
                                label="说话人分离",
                                value=True
                            )

                        with gr.Row():
                            speaker_count = gr.Number(
                                label="说话人数量(可选)",
                                value=None,
                                minimum=None,
                                maximum=100,
                                step=1,
                                info="留空则自动判断,如需指定请输入2-100之间的数值"
                            )
                            channel_select = gr.Textbox(
                                label="音轨索引",
                                value="0",
                                info="多音轨文件的音轨索引,用逗号分隔"
                            )

                        # Advanced options (nested, also collapsed).
                        with gr.Accordion("高级选项", open=False):
                            vocabulary_id = gr.Textbox(
                                label="热词ID v2",
                                value="",
                                info="v2模型的热词ID"
                            )
                            phrase_id = gr.Textbox(
                                label="热词ID v1",
                                value="",
                                info="v1模型的热词ID"
                            )
                            special_word_filter = gr.Textbox(
                                label="敏感词过滤配置",
                                value="",
                                lines=2,
                                placeholder='JSON格式配置',
                                info="敏感词过滤的JSON配置"
                            )

                    with gr.Row():
                        start_btn = gr.Button("🚀 开始转录", variant="primary", size="lg")
                        cancel_btn = gr.Button("❌ 取消", variant="secondary")
                        clear_btn = gr.Button("🗑️ 清空", variant="secondary")

                with gr.Column(scale=2):
                    status_text = gr.Textbox(
                        label="📊 当前状态",
                        value="等待上传文件...",
                        interactive=False,
                        elem_classes=["status-simple"]
                    )
                    result_text = gr.Textbox(
                        label="📝 转录结果",
                        placeholder="转录结果将在这里显示...",
                        lines=12,
                        max_lines=20,
                        show_copy_button=True,
                        elem_classes=["result-area"]
                    )
                    # Per-file statistics, hidden until a completed task yields rows.
                    # Column order must match _build_stats_row.
                    stats_df = gr.Dataframe(
                        headers=["文件名", "时长", "文本长度", "语言", "置信度"],
                        datatype=["str", "str", "number", "str", "number"],
                        label="📈 处理统计",
                        visible=False
                    )

            # Collapsible details section.
            with gr.Accordion("📋 详细信息", open=False):
                with gr.Tabs():
                    with gr.Tab("系统信息"):
                        system_info = gr.JSON(
                            label="服务状态",
                            value=self._get_system_info()
                        )
                        gr.JSON(
                            label="支持格式",
                            value=supported_formats
                        )
                    with gr.Tab("任务信息"):
                        task_info = gr.JSON(
                            label="当前任务",
                            value={}
                        )
                    with gr.Tab("完整结果"):
                        result_json = gr.JSON(
                            label="JSON结果",
                            value={}
                        )
                    with gr.Tab("处理日志"):
                        log_text = gr.Textbox(
                            label="详细日志",
                            lines=8,
                            max_lines=12,
                            interactive=False,
                            show_copy_button=True
                        )
                        gr.File(
                            label="下载日志文件",
                            visible=False
                        )

            with gr.Row():
                refresh_btn = gr.Button("🔄 刷新状态", variant="secondary", size="sm")

            # Components refreshed by _update_interface, in its return order.
            refresh_outputs = [
                status_text, task_info, result_text, result_json,
                stats_df, system_info, log_text
            ]

            refresh_btn.click(
                fn=self._update_interface,
                outputs=refresh_outputs
            )
            start_btn.click(
                fn=self._process_files,
                inputs=[
                    file_upload, priority_select, language_select,
                    disfluency_removal, timestamp_alignment, diarization_enabled,
                    speaker_count, channel_select, vocabulary_id,
                    phrase_id, special_word_filter
                ],
                outputs=[status_text, task_info, log_text]
            )
            cancel_btn.click(
                fn=self._cancel_current_task,
                outputs=[status_text, task_info]
            )
            clear_btn.click(
                fn=self._clear_interface,
                outputs=[file_upload, result_text, result_json, stats_df, log_text, status_text, task_info]
            )

            # Populate the panels once whenever a page is (re)loaded.
            interface.load(
                fn=self._update_interface,
                outputs=refresh_outputs
            )
            # Blocks.load fires only once per page load, so without a poll the
            # task progress/results appear only after pressing refresh.
            # gr.Timer exists only in newer Gradio releases, hence the guard.
            if hasattr(gr, "Timer"):
                poll_timer = gr.Timer(self._POLL_INTERVAL_SECONDS)
                poll_timer.tick(
                    fn=self._update_interface,
                    outputs=refresh_outputs,
                    show_progress="hidden"
                )

        return interface

    def _get_custom_css(self) -> str:
        """Return an alternative CSS theme string (not referenced elsewhere in this class)."""
        return """
        .gradio-container {
            max-width: 1200px !important;
        }
        .gr-button-primary {
            background: linear-gradient(45deg, #FF6B6B, #4ECDC4) !important;
            border: none !important;
        }
        .gr-button-primary:hover {
            transform: translateY(-2px) !important;
            box-shadow: 0 4px 12px rgba(0,0,0,0.15) !important;
        }
        .progress-bar {
            background: linear-gradient(90deg, #FF6B6B, #4ECDC4) !important;
        }
        """

    def _get_system_info(self) -> Dict:
        """Summarise task-manager statistics for the system-info panel.

        Counters missing from ``get_statistics()`` are reported as 0 instead
        of raising KeyError.
        """
        stats = self.task_manager.get_statistics()

        def count(key: str) -> int:
            return stats.get(key, 0)

        return {
            "服务状态": "运行中",
            "当前任务数": count('total_tasks'),
            "待处理": count('pending'),
            "处理中": count('validating') + count('uploading') + count('transcribing'),
            "已完成": count('completed'),
            "失败": count('failed'),
            "队列大小": count('queue_size')
        }

    def _get_timestamp(self) -> str:
        """Return the current local time formatted as ``YYYY-mm-dd HH:MM:SS``."""
        from datetime import datetime
        return datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    async def _process_files(
        self,
        files: List,
        priority: str,
        languages: List[str],
        disfluency_removal: bool,
        timestamp_alignment: bool,
        diarization_enabled: bool,
        speaker_count: Optional[float],
        channel_id: str,
        vocabulary_id: str,
        phrase_id: str,
        special_word_filter: str
    ) -> Tuple[str, Dict, str]:
        """Create a transcription task from the uploaded files and form values.

        Args:
            files: Uploaded files, either path strings or objects with ``.name``.
            priority: "HIGH" or "NORMAL".
            languages: Selected language codes (passed as ``language_hints``).
            disfluency_removal: Whether to filter filler words.
            timestamp_alignment: Whether to enable timestamp alignment.
            diarization_enabled: Whether to enable speaker diarization.
            speaker_count: Optional speaker-count hint; only 2-100 is accepted.
            channel_id: Comma-separated track indices; defaults to [0].
            vocabulary_id: Hot-word ID for v2 models (blank -> None).
            phrase_id: Hot-word ID for v1 models (blank -> None).
            special_word_filter: JSON text for sensitive-word filtering.

        Returns:
            (status message, task dict, log text). Errors are reported through
            the same tuple rather than raised.
        """
        try:
            if not files:
                return "请先上传音频文件", {}, "错误: 未选择任何文件"

            log_messages = [
                f"[{self._get_timestamp()}] 开始处理文件上传请求",
                f"[{self._get_timestamp()}] 接收到 {len(files)} 个文件",
            ]

            # Gradio may deliver plain path strings or file-like objects that
            # expose the path as ``.name``; accept both.
            file_paths = [Path(getattr(f, "name", f)) for f in files]
            log_messages.append(f"[{self._get_timestamp()}] 转换文件路径完成")

            for i, file_path in enumerate(file_paths, start=1):
                try:
                    file_size = file_path.stat().st_size
                    log_messages.append(f"[{self._get_timestamp()}] 文件 {i}: {file_path.name} (大小: {file_size} 字节)")
                except OSError as e:
                    log_messages.append(f"[{self._get_timestamp()}] 文件 {i}: {file_path.name} (无法获取文件信息: {str(e)})")

            # Track indices; an empty or malformed field falls back to track 0.
            try:
                channel_list = [int(x.strip()) for x in (channel_id or "").split(',') if x.strip()]
            except ValueError:
                channel_list = []
            if not channel_list:
                channel_list = [0]

            # Speaker count is only forwarded when it lies in [2, 100].
            validated_speaker_count = None
            if speaker_count is not None:
                if isinstance(speaker_count, (int, float)) and 2 <= speaker_count <= 100:
                    validated_speaker_count = int(speaker_count)
                else:
                    log_messages.append(f"[{self._get_timestamp()}] 警告: 说话人数量无效({speaker_count}),将使用自动判断")

            # Sensitive-word filter: invalid JSON is logged and ignored.
            special_filter = None
            if (special_word_filter or "").strip():
                try:
                    special_filter = json.loads(special_word_filter)
                except json.JSONDecodeError:
                    log_messages.append(f"[{self._get_timestamp()}] 警告: 敏感词过滤配置格式错误,将使用默认设置")

            task_priority = TaskPriority.HIGH if priority == "HIGH" else TaskPriority.NORMAL

            # Metadata carrying every Paraformer parameter for the worker.
            metadata = {
                "languages": languages,
                "file_count": len(file_paths),
                "paraformer_params": {
                    "language_hints": languages,
                    "channel_id": channel_list,
                    "disfluency_removal_enabled": disfluency_removal,
                    "timestamp_alignment_enabled": timestamp_alignment,
                    "diarization_enabled": diarization_enabled,
                    "speaker_count": validated_speaker_count,
                    "vocabulary_id": (vocabulary_id or "").strip() or None,
                    "phrase_id": (phrase_id or "").strip() or None,
                    "special_word_filter": json.dumps(special_filter) if special_filter else None
                }
            }

            log_messages.append(f"[{self._get_timestamp()}] 创建任务,优先级: {task_priority.value}")
            log_messages.append(f"[{self._get_timestamp()}] 选择语言: {', '.join(languages) if languages else '自动识别'}")

            self.current_task_id = await self.task_manager.create_task(
                file_paths=file_paths,
                priority=task_priority,
                metadata=metadata
            )
            task = self.task_manager.get_task(self.current_task_id)
            log_messages.append(f"[{self._get_timestamp()}] 任务创建成功,任务ID: {self.current_task_id}")

            return (
                f"任务已创建: {self.current_task_id}",
                task.to_dict() if task else {},
                "\n".join(log_messages) + f"\n开始处理 {len(file_paths)} 个文件...\n"
            )
        except Exception as e:
            # UI boundary: report the failure in the panel instead of crashing the handler.
            error_msg = f"创建任务失败: {str(e)}"
            self.logger.exception(error_msg)
            return error_msg, {}, f"错误: {error_msg}\n"

    async def _cancel_current_task(self) -> Tuple[str, Dict]:
        """Cancel the current task and report the outcome.

        The cancel coroutine is awaited here. The previous approach,
        ``asyncio.create_task`` from a sync handler, raises without a running
        loop, and it always treated the (truthy) Task object as success.

        Returns:
            (status message, task dict or {}).
        """
        task_id = self.current_task_id
        if not task_id:
            return "没有正在执行的任务", {}

        try:
            # NOTE(review): assumes cancel_task returns a truthy value on success.
            success = await self.task_manager.cancel_task(task_id)
        except Exception:
            self.logger.exception(f"取消任务失败: {task_id}")
            return "取消任务失败", {}

        if not success:
            return "取消任务失败", {}

        task = self.task_manager.get_task(task_id)
        return f"任务 {task_id} 已取消", task.to_dict() if task else {}

    def _clear_interface(self) -> Tuple[None, str, Dict, Any, str, str, Dict]:
        """Forget the current task and reset every output panel.

        Returns:
            Values for (file_upload, result_text, result_json, stats_df,
            log_text, status_text, task_info); the stats table is also hidden.
        """
        self.current_task_id = None
        return (
            None,                                       # file_upload
            "",                                         # result_text
            {},                                         # result_json
            gr.update(value=[], visible=False),         # stats_df
            "",                                         # log_text
            "界面已清空,等待上传文件...",                # status_text
            {}                                          # task_info
        )

    @staticmethod
    def _file_label(transcription: Dict) -> str:
        """Return the last path segment of a transcription's ``file_url`` (tolerates None)."""
        return (transcription.get('file_url') or '').split('/')[-1]

    @staticmethod
    def _build_stats_row(transcription: Dict) -> List[Any]:
        """Build one statistics row: [file name, duration, text length, language, confidence]."""
        duration = transcription.get('duration') or 0
        confidence = transcription.get('confidence') or 0
        return [
            GradioInterface._file_label(transcription),
            f"{duration:.1f}s",
            len(transcription.get('text') or ''),
            transcription.get('language') or 'unknown',
            round(confidence, 3),
        ]

    def _update_interface(self) -> Tuple[str, Dict, str, Dict, Any, Dict, str]:
        """Compute fresh values for the status/result/statistics/log panels.

        Returns:
            (status_text, task_info, result_text, result_json, stats_df update,
            system_info, log_text), in the order of the refresh outputs.
        """
        status_text = "等待上传文件..."
        task_info: Dict = {}
        result_text = ""
        result_json: Dict = {}
        stats_rows: List[List[Any]] = []
        log_text = ""

        task = self.task_manager.get_task(self.current_task_id) if self.current_task_id else None
        if task:
            task_info = task.to_dict()
            progress_message = task.progress.message if task.progress else ""
            status_text = f"[{task.status.value}] {progress_message}"
            log_text = self._collect_task_logs(task)

            if task.status == TaskStatus.COMPLETED:
                # task.result may be None; getattr avoids an AttributeError here.
                results = getattr(task.result, 'transcription_results', None)
                self.logger.debug(f"任务已完成,检查转录结果: {results}")
                if results:
                    result_json = results
                    transcriptions = results.get('transcriptions', [])
                    self.logger.debug(f"转录结果: {transcriptions}")
                    result_text = "\n\n".join(
                        f"文件: {self._file_label(t)}\n{t.get('text', '')}"
                        for t in transcriptions if t.get('text')
                    )
                    # Entries carrying an 'error' key are left out of the table.
                    stats_rows = [self._build_stats_row(t) for t in transcriptions if 'error' not in t]
                else:
                    self.logger.debug("任务已完成但没有转录结果")
            elif task.status == TaskStatus.FAILED:
                if task.result and task.result.error_message:
                    log_text += f"\n[{self._get_timestamp()}] 任务失败: {task.result.error_message}"

        system_info = self._get_system_info()
        # Show the statistics table only when there is something to show.
        stats_update = gr.update(value=stats_rows, visible=bool(stats_rows))
        return status_text, task_info, result_text, result_json, stats_update, system_info, log_text

    def _collect_task_logs(self, task) -> str:
        """Format a human-readable log summary of a task.

        Args:
            task: Task object from the task manager (may be None).

        Returns:
            Newline-joined log lines. Each bracketed timestamp is the time of
            this call, not the time of the event it describes.
        """
        if not task:
            return "无任务信息"

        log_lines = [
            f"[{self._get_timestamp()}] 任务ID: {task.id}",
            f"[{self._get_timestamp()}] 任务状态: {task.status.value}",
            f"[{self._get_timestamp()}] 任务创建时间: {task.created_at}",
        ]

        # Only the progress message is available (the progress object has no details field).
        if task.progress:
            log_lines.append(f"[{self._get_timestamp()}] 进度信息: {task.progress.message}")

        file_paths = getattr(task, 'file_paths', None)
        if file_paths:
            log_lines.append(f"[{self._get_timestamp()}] 文件列表:")
            for i, file_path in enumerate(file_paths, start=1):
                try:
                    file_size = file_path.stat().st_size
                    log_lines.append(f"  {i}. {file_path.name} ({file_size} bytes)")
                except Exception as e:
                    log_lines.append(f"  {i}. {file_path.name} (无法获取文件信息: {str(e)})")

        if task.status == TaskStatus.COMPLETED and task.result:
            log_lines.append(f"[{self._get_timestamp()}] 任务完成时间: {task.completed_at}")
            results = getattr(task.result, 'transcription_results', None)
            if results:
                transcriptions = results.get('transcriptions', [])
                log_lines.append(f"[{self._get_timestamp()}] 转录结果: {len(transcriptions)} 个文件")

        # Error details live on task.result and are appended by _update_interface.
        return "\n".join(log_lines)

    def _on_task_status_change(self, task):
        """Task-manager status callback; it only logs the transition.

        Gradio components cannot be updated from here, so the UI picks up
        changes the next time _update_interface runs.
        """
        self.logger.debug(f"任务状态变化: {task.id} -> {task.status.value}")

    def launch(self, **kwargs):
        """Launch the Gradio app; keyword arguments override the defaults below.

        Returns:
            Whatever ``gr.Blocks.launch`` returns.
        """
        # NOTE(review): the defaults bind to all interfaces and request a public
        # share link; pass server_name/share explicitly to restrict exposure.
        default_kwargs = {
            'server_name': '0.0.0.0',
            'server_port': self.config.app.port,
            'share': True,
            'debug': self.config.app.debug,
            'show_error': True,
            'quiet': not self.config.app.debug
        }
        default_kwargs.update(kwargs)
        self.logger.info(f"启动Gradio界面: http://{default_kwargs['server_name']}:{default_kwargs['server_port']}")
        return self.interface.launch(**default_kwargs)
# Process-wide singleton, created when this module is imported. Constructing
# it builds the Gradio Blocks and registers a task-status callback.
gradio_interface = GradioInterface()


def get_gradio_interface() -> GradioInterface:
    """Return the module-level GradioInterface singleton.

    Returns:
        The instance created at import time.
    """
    return gradio_interface


def create_demo_interface() -> gr.Blocks:
    """Return the Blocks object owned by the module-level singleton.

    No new interface is built; every call yields the same Blocks object.

    Returns:
        The singleton's ``interface`` attribute.
    """
    singleton = get_gradio_interface()
    return singleton.interface