"""
Quick Mode task service.

Business logic for one-click ("quick mode") training tasks: creating and
enqueueing a task, querying and cancelling it, streaming its progress, and
listing/downloading the artifacts it produced.
"""

import asyncio
import os
import re
import uuid
from datetime import datetime
from pathlib import Path
from typing import AsyncGenerator, Dict, Optional, Any, Tuple

from project_config import settings

from ..core.adapters import get_database_adapter, get_task_queue_adapter, get_storage_adapter
from ..models.domain import Task, TaskStatus
from ..models.schemas.task import (
    QuickModeRequest,
    TaskResponse,
    TaskListResponse,
    InferenceOutputItem,
    InferenceOutputsResponse,
)

# Quality presets: epoch counts per quality level.
QUALITY_PRESETS = {
    "fast": {
        "sovits_epochs": 4,
        "gpt_epochs": 8,
        "description": "快速训练,约10分钟",
    },
    "standard": {
        "sovits_epochs": 8,
        "gpt_epochs": 15,
        "description": "标准训练,约20分钟",
    },
    "high": {
        "sovits_epochs": 16,
        "gpt_epochs": 30,
        "description": "高质量训练,约40分钟",
    },
}

# Default inference test text per language.
DEFAULT_TARGET_TEXTS = {
    "zh": "这是一段测试语音合成的文本。请你用自然、清晰、不过度夸张的语气朗读,并在逗号和句号处做适当停顿:先慢一点,再稍微快一点,最后恢复正常语速。",
    "en": "This is a test text for speech synthesis. Please read it naturally and clearly, without exaggeration, pausing appropriately at commas and periods: start slowly, then speed up a bit, and finally return to normal pace.",
    "ja": "これは音声合成のテストテキストです。自然で明瞭に、大げさにならないように朗読してください。読点と句点で適切に間を置いて:最初はゆっくり、少し速く、最後は普通の速さに戻してください。",
    "ko": "이것은 음성 합성을 위한 테스트 텍스트입니다. 자연스럽고 명확하게, 과장하지 않고 읽어주세요. 쉼표와 마침표에서 적절히 멈추며: 먼저 천천히, 그 다음 조금 빠르게, 마지막으로 보통 속도로 돌아오세요.",
    "yue": "呢段係測試語音合成嘅文字。請你用自然、清楚、唔好太誇張嘅語氣讀出嚟,喺逗號同句號嗰度要適當噉停一停:開頭慢啲,跟住快少少,最後返返正常語速。",
}

# Model version -> weight directory name (relative to the experiment directory).
# Built once at import time instead of on every lookup.
_GPT_WEIGHT_DIRS = {
    "v1": "GPT_weights",
    "v2": "GPT_weights_v2",
    "v3": "GPT_weights_v3",
    "v4": "GPT_weights_v4",
    "v2Pro": "GPT_weights_v2Pro",
    "v2ProPlus": "GPT_weights_v2ProPlus",
}
_SOVITS_WEIGHT_DIRS = {
    "v1": "SoVITS_weights",
    "v2": "SoVITS_weights_v2",
    "v3": "SoVITS_weights_v3",
    "v4": "SoVITS_weights_v4",
    "v2Pro": "SoVITS_weights_v2Pro",
    "v2ProPlus": "SoVITS_weights_v2ProPlus",
}


class TaskService:
    """
    Quick Mode task service.

    Manages the life cycle of a one-click training task:
    - create a task
    - query task status
    - cancel a task
    - subscribe to progress updates

    Example:
        >>> service = TaskService()
        >>> task = await service.create_quick_task(request)
        >>> status = await service.get_task(task.id)
        >>> await service.cancel_task(task.id)
    """

    def __init__(self):
        """Initialise the service; adapters are resolved lazily on first use."""
        self._db = None
        self._queue = None
        self._storage = None

    @property
    def db(self):
        """Database adapter, fetched on first access and then cached."""
        if self._db is None:
            self._db = get_database_adapter()
        return self._db

    @property
    def queue(self):
        """Task queue adapter, fetched on first access and then cached."""
        if self._queue is None:
            self._queue = get_task_queue_adapter()
        return self._queue

    @property
    def storage(self):
        """Storage adapter, fetched on first access and then cached."""
        if self._storage is None:
            self._storage = get_storage_adapter()
        return self._storage

    async def check_exp_name_exists(self, exp_name: str) -> bool:
        """
        Check whether an experiment name is already in use.

        Args:
            exp_name: Experiment name.

        Returns:
            True if the database returns a task for this name, else False.
        """
        existing_task = await self.db.get_task_by_exp_name(exp_name)
        return existing_task is not None

    async def validate_audio_file(self, audio_file_id: str) -> Tuple[bool, str]:
        """
        Check that an audio file exists and resolve its path.

        Args:
            audio_file_id: Audio file ID, or a filesystem path.

        Returns:
            (exists, path). When the storage adapter has metadata for the ID,
            the path is ``storage.base_path / audio_file_id``; otherwise the
            argument itself is treated as the path.
        """
        file_metadata = await self.storage.get_file_metadata(audio_file_id)
        if file_metadata:
            # Stored files live at storage.base_path / file_id.
            audio_file_path = str(self.storage.base_path / audio_file_id)
        else:
            # No metadata: treat the ID as a plain path.
            audio_file_path = audio_file_id
        return os.path.exists(audio_file_path), audio_file_path

    async def create_quick_task(self, request: QuickModeRequest) -> TaskResponse:
        """
        Create a one-click training task.

        Builds the training configuration from the request and the selected
        quality preset, stores the task, enqueues it and records the job ID.

        Args:
            request: Quick mode request.

        Returns:
            TaskResponse describing the newly created task.
        """
        task_id = f"task-{uuid.uuid4().hex[:12]}"

        # Unknown quality names fall back to the "standard" preset.
        quality = request.options.quality
        preset = QUALITY_PRESETS.get(quality, QUALITY_PRESETS["standard"])

        # Resolve the audio path. The existence flag is not enforced here.
        audio_file_id = request.audio_file_id
        _, audio_file_path = await self.validate_audio_file(audio_file_id)

        stages = [
            "audio_slice",
            "asr",
            "text_feature",
            "hubert_feature",
            "semantic_token",
            "sovits_train",
            "gpt_train",
        ]

        # Inference runs unless options are present and explicitly disable it.
        inference_opts = request.options.inference
        inference_enabled = inference_opts is None or inference_opts.enabled
        if inference_enabled:
            stages.append("inference")

        language = request.options.language
        config = {
            "exp_name": request.exp_name,
            "audio_file_id": audio_file_id,
            "input_path": audio_file_path,  # actual path of the audio file
            "version": request.options.version,
            "language": language,
            "quality": quality,
            # Training parameters
            "total_epoch": preset["sovits_epochs"],  # SoVITS epoch
            "sovits_epochs": preset["sovits_epochs"],
            "gpt_epochs": preset["gpt_epochs"],
            # Pretrained model paths
            "bert_pretrained_dir": str(settings.BERT_PRETRAINED_DIR),
            "ssl_pretrained_dir": str(settings.SSL_PRETRAINED_DIR),
            "pretrained_s2G": str(settings.PRETRAINED_S2G),
            "pretrained_s2D": str(settings.PRETRAINED_S2D),
            "pretrained_s1": str(settings.PRETRAINED_S1),
            # Stages to execute
            "stages": stages,
        }

        if inference_enabled:
            default_target_text = DEFAULT_TARGET_TEXTS.get(language, DEFAULT_TARGET_TEXTS["zh"])
            if inference_opts is not None:
                config["ref_text"] = inference_opts.ref_text or ""
                config["ref_audio_path"] = inference_opts.ref_audio_path or ""
                # Fall back to the language default when no target text was
                # given, mirroring the `or ""` fallbacks above.
                config["target_text"] = inference_opts.target_text or default_target_text
            else:
                # Empty strings mean "resolve from asr_opt/slicer_opt.list"
                # (see _parse_list_file).
                config["ref_text"] = ""
                config["ref_audio_path"] = ""
                config["target_text"] = default_target_text

        task = Task(
            id=task_id,
            exp_name=request.exp_name,
            config=config,
            status=TaskStatus.QUEUED,
            # NOTE: naive UTC timestamp, kept so it stays comparable with the
            # other naive timestamps written by this service.
            created_at=datetime.utcnow(),
        )

        await self.db.create_task(task)

        job_id = await self.queue.enqueue(task_id, config)

        await self.db.update_task(task_id, {"job_id": job_id})
        task.job_id = job_id

        return self._task_to_response(task)

    async def get_task(self, task_id: str) -> Optional[TaskResponse]:
        """
        Fetch a task.

        Args:
            task_id: Task ID.

        Returns:
            TaskResponse, or None when the task does not exist.
        """
        task = await self.db.get_task(task_id)
        if not task:
            return None
        return self._task_to_response(task)

    async def list_tasks(
        self,
        status: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> TaskListResponse:
        """
        List tasks.

        Args:
            status: Optional status filter.
            limit: Page size.
            offset: Offset of the first item.

        Returns:
            TaskListResponse with the page of tasks and the total count for
            the same filter.
        """
        tasks = await self.db.list_tasks(status=status, limit=limit, offset=offset)
        total = await self.db.count_tasks(status=status)

        return TaskListResponse(
            items=[self._task_to_response(t) for t in tasks],
            total=total,
            limit=limit,
            offset=offset,
        )

    async def cancel_task(self, task_id: str) -> bool:
        """
        Cancel a task.

        Args:
            task_id: Task ID.

        Returns:
            True if the task was marked cancelled; False if it does not exist
            or is not in a cancellable state.
        """
        task = await self.db.get_task(task_id)
        if not task:
            return False

        # Only queued or running tasks can be cancelled.
        if task.status not in (TaskStatus.QUEUED, TaskStatus.RUNNING):
            return False

        # Cancel the queue job when one was recorded.
        if task.job_id:
            await self.queue.cancel(task.job_id)

        await self.db.update_task(task_id, {
            "status": TaskStatus.CANCELLED,
            "completed_at": datetime.utcnow(),
            "message": "任务已取消",
        })

        return True

    async def subscribe_progress(
        self,
        task_id: str
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Subscribe to task progress (SSE stream).

        Args:
            task_id: Task ID.

        Yields:
            Progress dictionaries. A missing task yields one ``error`` item,
            an already finished task yields one ``final`` item; otherwise
            queue updates are forwarded until a terminal status is seen.
        """
        task = await self.db.get_task(task_id)
        if not task:
            yield {"type": "error", "message": "任务不存在"}
            return

        # Already finished: report the final state and stop.
        if task.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
            yield {
                "type": "final",
                "status": task.status.value,
                "message": task.message or task.error_message,
                "progress": task.progress,
            }
            return

        async for progress in self.queue.subscribe_progress(task_id):
            yield progress

            # Stop after forwarding a terminal update.
            if progress.get("status") in ("completed", "failed", "cancelled"):
                break

    async def get_inference_outputs(self, task_id: str) -> Optional[InferenceOutputsResponse]:
        """
        List the inference outputs of a task.

        Scans ``EXP_ROOT/{exp_name}/inference/`` for ``*.wav`` files and
        returns metadata for each one.

        Args:
            task_id: Task ID.

        Returns:
            InferenceOutputsResponse, or None when the task does not exist.
        """
        task = await self.db.get_task(task_id)
        if not task:
            return None

        exp_name = task.exp_name
        exp_dir = Path(settings.EXP_ROOT) / exp_name
        inference_dir = exp_dir / "inference"

        ref_text = task.config.get("ref_text", "")
        ref_audio_path = task.config.get("ref_audio_path", "")
        target_text = task.config.get("target_text", "")

        # If either reference value is missing, take the pair from the .list
        # file. When the file yields nothing, keep whatever was configured
        # instead of discarding it.
        if not ref_text or not ref_audio_path:
            parsed_audio, parsed_text = self._parse_list_file(exp_name)
            if parsed_audio or parsed_text:
                ref_audio_path = parsed_audio
                ref_text = parsed_text

        # The model version selects the weight directories.
        version = task.config.get("version", "v2")
        gpt_weight_dir = self._get_gpt_weight_dir(version)
        sovits_weight_dir = self._get_sovits_weight_dir(version)

        outputs = []

        if inference_dir.is_dir():
            # Sorted so the listing order is stable between calls.
            for file_path in sorted(inference_dir.glob("*.wav")):
                filename = file_path.name

                # Name format: {exp_name}_gpt-{gpt_name}_sovits-{sovits_name}.wav
                gpt_model, sovits_model = self._parse_inference_filename(filename, exp_name)

                gpt_path = str(exp_dir / gpt_weight_dir / f"{gpt_model}.ckpt")
                sovits_path = str(exp_dir / sovits_weight_dir / f"{sovits_model}.pth")

                # Report the path relative to the project root when possible;
                # EXP_ROOT may lie outside it, in which case relative_to()
                # raises ValueError and the full path is reported.
                try:
                    display_path = str(file_path.relative_to(settings.PROJECT_ROOT))
                except ValueError:
                    display_path = str(file_path)

                stat = file_path.stat()

                outputs.append(InferenceOutputItem(
                    filename=filename,
                    gpt_model=gpt_model,
                    sovits_model=sovits_model,
                    gpt_path=gpt_path,
                    sovits_path=sovits_path,
                    file_path=display_path,
                    size_bytes=stat.st_size,
                    # NOTE: st_ctime is the metadata-change time on Unix and
                    # the creation time on Windows.
                    created_at=datetime.fromtimestamp(stat.st_ctime),
                ))

        return InferenceOutputsResponse(
            task_id=task_id,
            exp_name=exp_name,
            ref_text=ref_text,
            ref_audio_path=ref_audio_path,
            target_text=target_text,
            outputs=outputs,
            total=len(outputs),
        )

    async def download_inference_output(
        self,
        task_id: str,
        filename: str
    ) -> Optional[Tuple[bytes, str, str]]:
        """
        Download one inference output file.

        Args:
            task_id: Task ID.
            filename: File name inside the task's inference directory.

        Returns:
            (file content, file name, content type), or None when the task or
            file does not exist or the name escapes the inference directory.
        """
        task = await self.db.get_task(task_id)
        if not task:
            return None

        inference_dir = Path(settings.EXP_ROOT) / task.exp_name / "inference"

        # Containment is checked on path components. A string-prefix test
        # would also accept sibling directories such as "inference_old".
        file_path = self._resolve_within(inference_dir, filename)
        if file_path is None:
            return None

        file_data = await self._read_file_bytes(file_path)
        if file_data is None:
            return None

        return file_data, filename, "audio/wav"

    async def download_file(
        self,
        task_id: str,
        file_type: str,
        filename: str
    ) -> Optional[Tuple[bytes, str, str]]:
        """
        Download a file of the given type.

        Args:
            task_id: Task ID.
            file_type: One of output / ref_audio / gpt_model / sovits_model.
            filename: File name; for ``ref_audio`` this is a full path.

        Returns:
            (file content, download file name, content type), or None when
            the task, file type or file is not found, or when the name
            escapes the directory that belongs to the file type.
        """
        task = await self.db.get_task(task_id)
        if not task:
            return None

        exp_dir = Path(settings.EXP_ROOT) / task.exp_name
        version = task.config.get("version", "v2")

        if file_type == "ref_audio":
            # ref_audio uses a full path (the filename parameter is the path).
            # NOTE(security): the path comes from the caller and is not
            # confined to any directory, so this branch can read any file the
            # process can access. The callers' legitimate path roots are not
            # visible in this file, so the behaviour is kept and flagged here
            # rather than restricted.
            try:
                file_path = Path(filename).resolve()
            except (ValueError, OSError, RuntimeError):
                return None
            content_type = "audio/wav"
        else:
            if file_type == "output":
                base_dir = exp_dir / "inference"
                content_type = "audio/wav"
            elif file_type == "gpt_model":
                base_dir = exp_dir / self._get_gpt_weight_dir(version)
                content_type = "application/octet-stream"
            elif file_type == "sovits_model":
                base_dir = exp_dir / self._get_sovits_weight_dir(version)
                content_type = "application/octet-stream"
            else:
                return None

            # Reject names that resolve outside the directory for this file
            # type ("../", absolute paths, symlinks leading elsewhere).
            file_path = self._resolve_within(base_dir, filename)
            if file_path is None:
                return None

        file_data = await self._read_file_bytes(file_path)
        if file_data is None:
            return None

        # The download name is the bare file name, without directories.
        return file_data, file_path.name, content_type

    @staticmethod
    def _resolve_within(base_dir: Path, filename: str) -> Optional[Path]:
        """Resolve base_dir/filename; return None if the result leaves base_dir."""
        try:
            resolved_base = base_dir.resolve()
            candidate = (base_dir / filename).resolve()
            # relative_to() raises ValueError when candidate is not inside
            # resolved_base.
            candidate.relative_to(resolved_base)
        except (ValueError, OSError, RuntimeError):
            return None
        return candidate

    @staticmethod
    async def _read_file_bytes(file_path: Path) -> Optional[bytes]:
        """Return the bytes of a regular file (None if it is not one), read off the event loop."""
        if not file_path.is_file():
            return None
        # Model checkpoints can be large; read in a worker thread so the
        # event loop is not blocked.
        return await asyncio.to_thread(file_path.read_bytes)

    def _parse_inference_filename(self, filename: str, exp_name: str) -> Tuple[str, str]:
        """
        Extract the GPT and SoVITS model names from an inference output name.

        Expected format: {exp_name}_gpt-{gpt_name}_sovits-{sovits_name}.wav

        Args:
            filename: File name.
            exp_name: Experiment name.

        Returns:
            (gpt_model, sovits_model); a part that cannot be found is
            reported as "unknown".
        """
        # Drop the extension (text after the last dot).
        name = filename.rsplit(".", 1)[0] if "." in filename else filename

        # Strict form: {exp_name}_gpt-{gpt_name}_sovits-{sovits_name}
        pattern = rf"^{re.escape(exp_name)}_gpt-(.+)_sovits-(.+)$"
        match = re.match(pattern, name)
        if match:
            return match.group(1), match.group(2)

        # Fallback: pick up "gpt-…" and "sovits-…" anywhere in the name, each
        # ending at the next underscore.
        gpt_match = re.search(r"gpt-([^_]+)", name)
        sovits_match = re.search(r"sovits-([^_]+)", name)

        gpt_model = gpt_match.group(1) if gpt_match else "unknown"
        sovits_model = sovits_match.group(1) if sovits_match else "unknown"

        return gpt_model, sovits_model

    def _parse_list_file(self, exp_name: str) -> Tuple[str, str]:
        """
        Read ref_audio_path and ref_text from the first line of
        asr_opt/slicer_opt.list.

        Args:
            exp_name: Experiment name.

        Returns:
            (ref_audio_path, ref_text); two empty strings when the file is
            missing, unreadable, not valid UTF-8, or its first line does not
            have four fields.
        """
        list_path = Path(settings.EXP_ROOT) / exp_name / 'asr_opt' / 'slicer_opt.list'

        try:
            with open(list_path, 'r', encoding='utf-8') as f:
                first_line = f.readline().strip()
        except (OSError, UnicodeDecodeError):
            return "", ""

        if not first_line:
            return "", ""

        # Line format: {audio path}|{folder name}|{language}|{recognised text}
        # maxsplit=3 keeps any '|' inside the recognised text intact.
        parts = first_line.split('|', 3)
        if len(parts) >= 4:
            return parts[0], parts[3]
        return "", ""

    def _get_gpt_weight_dir(self, version: str) -> str:
        """Return the GPT weight directory name for a model version (v2 dir if unknown)."""
        return _GPT_WEIGHT_DIRS.get(version, "GPT_weights_v2")

    def _get_sovits_weight_dir(self, version: str) -> str:
        """Return the SoVITS weight directory name for a model version (v2 dir if unknown)."""
        return _SOVITS_WEIGHT_DIRS.get(version, "SoVITS_weights_v2")

    def _task_to_response(self, task: Task) -> TaskResponse:
        """Convert a Task domain model into a TaskResponse."""
        return TaskResponse(
            id=task.id,
            exp_name=task.exp_name,
            # status may be a TaskStatus member or an already plain value.
            status=task.status.value if isinstance(task.status, TaskStatus) else task.status,
            current_stage=task.current_stage,
            progress=task.stage_progress,
            overall_progress=task.progress,
            message=task.message,
            error_message=task.error_message,
            created_at=task.created_at,
            started_at=task.started_at,
            completed_at=task.completed_at,
        )