# Last commit author: liumaolin
# Last commit message: refactor(config): centralize configuration management in `project_config`
# Last commit hash: 8f68d0a
"""
Quick Mode 任务服务
处理一键训练任务的业务逻辑
"""
import re
import uuid
from datetime import datetime
from pathlib import Path
from typing import AsyncGenerator, Dict, Optional, Any, Tuple
from project_config import settings
from ..core.adapters import get_database_adapter, get_task_queue_adapter, get_storage_adapter
from ..models.domain import Task, TaskStatus
from ..models.schemas.task import (
QuickModeRequest,
TaskResponse,
TaskListResponse,
InferenceOutputItem,
InferenceOutputsResponse,
)
# Quality presets: for each quality level, the SoVITS and GPT epoch counts
# plus a human-readable description of the preset.
QUALITY_PRESETS = {
    level: {
        "sovits_epochs": sovits_epochs,
        "gpt_epochs": gpt_epochs,
        "description": description,
    }
    for level, sovits_epochs, gpt_epochs, description in (
        ("fast", 4, 8, "快速训练,约10分钟"),
        ("standard", 8, 15, "标准训练,约20分钟"),
        ("high", 16, 30, "高质量训练,约40分钟"),
    )
}
# Default inference test sentence for each supported language code.
DEFAULT_TARGET_TEXTS = dict(
    zh="这是一段测试语音合成的文本。请你用自然、清晰、不过度夸张的语气朗读,并在逗号和句号处做适当停顿:先慢一点,再稍微快一点,最后恢复正常语速。",
    en="This is a test text for speech synthesis. Please read it naturally and clearly, without exaggeration, pausing appropriately at commas and periods: start slowly, then speed up a bit, and finally return to normal pace.",
    ja="これは音声合成のテストテキストです。自然で明瞭に、大げさにならないように朗読してください。読点と句点で適切に間を置いて:最初はゆっくり、少し速く、最後は普通の速さに戻してください。",
    ko="이것은 음성 합성을 위한 테스트 텍스트입니다. 자연스럽고 명확하게, 과장하지 않고 읽어주세요. 쉼표와 마침표에서 적절히 멈추며: 먼저 천천히, 그 다음 조금 빠르게, 마지막으로 보통 속도로 돌아오세요.",
    yue="呢段係測試語音合成嘅文字。請你用自然、清楚、唔好太誇張嘅語氣讀出嚟,喺逗號同句號嗰度要適當噉停一停:開頭慢啲,跟住快少少,最後返返正常語速。",
)
class TaskService:
    """
    Quick Mode task service.

    Manages the lifecycle of one-click training tasks:
    - create a task
    - query task status / list tasks
    - cancel a task
    - subscribe to progress updates
    - list and download inference outputs and model files

    The database, task-queue and storage adapters are obtained lazily, on
    first access of the corresponding property.

    Example:
        >>> service = TaskService()
        >>> task = await service.create_quick_task(request)
        >>> status = await service.get_task(task.id)
        >>> await service.cancel_task(task.id)
    """

    # Weight directory names (inside an experiment directory) per model version.
    _GPT_WEIGHT_DIRS = {
        "v1": "GPT_weights",
        "v2": "GPT_weights_v2",
        "v3": "GPT_weights_v3",
        "v4": "GPT_weights_v4",
        "v2Pro": "GPT_weights_v2Pro",
        "v2ProPlus": "GPT_weights_v2ProPlus",
    }
    _SOVITS_WEIGHT_DIRS = {
        "v1": "SoVITS_weights",
        "v2": "SoVITS_weights_v2",
        "v3": "SoVITS_weights_v3",
        "v4": "SoVITS_weights_v4",
        "v2Pro": "SoVITS_weights_v2Pro",
        "v2ProPlus": "SoVITS_weights_v2ProPlus",
    }

    def __init__(self):
        """Initialize the service with all adapters unresolved."""
        self._db = None
        self._queue = None
        self._storage = None

    @property
    def db(self):
        """Database adapter, fetched on first access and then cached."""
        if self._db is None:
            self._db = get_database_adapter()
        return self._db

    @property
    def queue(self):
        """Task queue adapter, fetched on first access and then cached."""
        if self._queue is None:
            self._queue = get_task_queue_adapter()
        return self._queue

    @property
    def storage(self):
        """Storage adapter, fetched on first access and then cached."""
        if self._storage is None:
            self._storage = get_storage_adapter()
        return self._storage

    async def check_exp_name_exists(self, exp_name: str) -> bool:
        """
        Check whether a task with the given experiment name already exists.

        Args:
            exp_name: Experiment name.

        Returns:
            True if the database returns a task for this name, otherwise False.
        """
        existing_task = await self.db.get_task_by_exp_name(exp_name)
        return existing_task is not None

    async def validate_audio_file(self, audio_file_id: str) -> tuple[bool, str]:
        """
        Check that an audio file exists and resolve its path.

        Args:
            audio_file_id: Audio file ID, or a filesystem path.

        Returns:
            (exists, path). When the storage adapter has metadata for the ID,
            the path is ``storage.base_path / audio_file_id``; otherwise
            ``audio_file_id`` itself is treated as the path.
        """
        import os

        file_metadata = await self.storage.get_file_metadata(audio_file_id)
        if file_metadata:
            # Known upload: the file lives at storage.base_path / file_id.
            audio_file_path = str(self.storage.base_path / audio_file_id)
        else:
            # No metadata: treat the identifier as a path.
            audio_file_path = audio_file_id
        return os.path.exists(audio_file_path), audio_file_path

    async def create_quick_task(self, request: "QuickModeRequest") -> "TaskResponse":
        """
        Create a one-click training task.

        Builds the task configuration from the request and the selected
        quality preset, stores the task, enqueues it, and records the job ID
        returned by the queue.

        Args:
            request: Quick mode request.

        Returns:
            TaskResponse describing the newly created task.
        """
        task_id = f"task-{uuid.uuid4().hex[:12]}"

        # Unknown quality levels fall back to the "standard" preset.
        quality = request.options.quality
        preset = QUALITY_PRESETS.get(quality, QUALITY_PRESETS["standard"])

        # Only the resolved path is used here; the existence flag is not checked.
        audio_file_id = request.audio_file_id
        _, audio_file_path = await self.validate_audio_file(audio_file_id)

        stages = [
            "audio_slice",
            "asr",
            "text_feature",
            "hubert_feature",
            "semantic_token",
            "sovits_train",
            "gpt_train",
        ]

        # Inference runs unless the request carries options that disable it.
        inference_opts = request.options.inference
        inference_enabled = inference_opts is None or inference_opts.enabled
        if inference_enabled:
            stages.append("inference")

        config = {
            "exp_name": request.exp_name,
            "audio_file_id": audio_file_id,
            "input_path": audio_file_path,  # actual path of the audio file
            "version": request.options.version,
            "language": request.options.language,
            "quality": quality,
            # Training parameters
            "total_epoch": preset["sovits_epochs"],  # SoVITS epoch
            "sovits_epochs": preset["sovits_epochs"],
            "gpt_epochs": preset["gpt_epochs"],
            # Pretrained model paths
            "bert_pretrained_dir": str(settings.BERT_PRETRAINED_DIR),
            "ssl_pretrained_dir": str(settings.SSL_PRETRAINED_DIR),
            "pretrained_s2G": str(settings.PRETRAINED_S2G),
            "pretrained_s2D": str(settings.PRETRAINED_S2D),
            "pretrained_s1": str(settings.PRETRAINED_S1),
            # Stages to execute
            "stages": stages,
        }

        if inference_enabled:
            # Default test sentence for the request language ("zh" if unknown).
            default_target_text = DEFAULT_TARGET_TEXTS.get(
                request.options.language, DEFAULT_TARGET_TEXTS["zh"]
            )
            if inference_opts:
                config["ref_text"] = inference_opts.ref_text or ""
                config["ref_audio_path"] = inference_opts.ref_audio_path or ""
                # Fall back to the default sentence when no target text is given,
                # mirroring the empty-value handling of the two fields above.
                config["target_text"] = inference_opts.target_text or default_target_text
            else:
                # Empty strings mean the values are resolved later from the
                # asr_opt/slicer_opt.list file.
                config["ref_text"] = ""
                config["ref_audio_path"] = ""
                config["target_text"] = default_target_text

        task = Task(
            id=task_id,
            exp_name=request.exp_name,
            config=config,
            status=TaskStatus.QUEUED,
            created_at=datetime.utcnow(),
        )

        # Persist first, then enqueue, then record the queue's job ID.
        await self.db.create_task(task)
        job_id = await self.queue.enqueue(task_id, config)
        await self.db.update_task(task_id, {"job_id": job_id})
        task.job_id = job_id

        return self._task_to_response(task)

    async def get_task(self, task_id: str) -> Optional["TaskResponse"]:
        """
        Fetch a task by ID.

        Args:
            task_id: Task ID.

        Returns:
            TaskResponse, or None when the task does not exist.
        """
        task = await self.db.get_task(task_id)
        if not task:
            return None
        return self._task_to_response(task)

    async def list_tasks(
        self,
        status: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> "TaskListResponse":
        """
        List tasks with optional status filtering and pagination.

        Args:
            status: Status to filter by (no filter when None).
            limit: Page size.
            offset: Offset of the first item.

        Returns:
            TaskListResponse with the page of tasks and the total count for
            the same status filter.
        """
        tasks = await self.db.list_tasks(status=status, limit=limit, offset=offset)
        total = await self.db.count_tasks(status=status)
        return TaskListResponse(
            items=[self._task_to_response(t) for t in tasks],
            total=total,
            limit=limit,
            offset=offset,
        )

    async def cancel_task(self, task_id: str) -> bool:
        """
        Cancel a task.

        Args:
            task_id: Task ID.

        Returns:
            True if the task was cancelled; False if it does not exist or is
            neither queued nor running.
        """
        task = await self.db.get_task(task_id)
        if not task:
            return False

        # Only queued or running tasks can be cancelled.
        if task.status not in (TaskStatus.QUEUED, TaskStatus.RUNNING):
            return False

        # Cancel the queue job when one has been recorded.
        if task.job_id:
            await self.queue.cancel(task.job_id)

        await self.db.update_task(task_id, {
            "status": TaskStatus.CANCELLED,
            "completed_at": datetime.utcnow(),
            "message": "任务已取消",
        })
        return True

    async def subscribe_progress(
        self,
        task_id: str
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Subscribe to task progress (SSE stream).

        Args:
            task_id: Task ID.

        Yields:
            Progress dictionaries. A single ``error`` item is yielded when the
            task does not exist, and a single ``final`` item when the task has
            already finished; otherwise queue updates are forwarded until one
            reports a terminal status.
        """
        task = await self.db.get_task(task_id)
        if not task:
            yield {"type": "error", "message": "任务不存在"}
            return

        # Already finished: report the final state and stop.
        if task.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
            yield {
                "type": "final",
                "status": task.status.value,
                "message": task.message or task.error_message,
                "progress": task.progress,
            }
            return

        async for progress in self.queue.subscribe_progress(task_id):
            yield progress
            # Stop after forwarding a terminal update.
            if progress.get("status") in ("completed", "failed", "cancelled"):
                break

    async def get_inference_outputs(self, task_id: str) -> Optional["InferenceOutputsResponse"]:
        """
        List the inference outputs of a task.

        Scans ``{EXP_ROOT}/{exp_name}/inference/`` for ``*.wav`` files and
        returns metadata for each of them, in sorted path order.

        Args:
            task_id: Task ID.

        Returns:
            InferenceOutputsResponse, or None when the task does not exist.
        """
        task = await self.db.get_task(task_id)
        if not task:
            return None

        exp_name = task.exp_name
        exp_dir = Path(settings.EXP_ROOT) / exp_name
        inference_dir = exp_dir / "inference"

        ref_text = task.config.get("ref_text", "")
        ref_audio_path = task.config.get("ref_audio_path", "")
        target_text = task.config.get("target_text", "")

        # If either reference value is missing, take both from the .list file.
        if not ref_text or not ref_audio_path:
            ref_audio_path, ref_text = self._parse_list_file(exp_name)

        # The model version determines the weight directory names.
        version = task.config.get("version", "v2")
        gpt_weight_dir = self._get_gpt_weight_dir(version)
        sovits_weight_dir = self._get_sovits_weight_dir(version)

        outputs = []
        if inference_dir.is_dir():
            # glob() order is filesystem-dependent; sort for a stable listing.
            for file_path in sorted(inference_dir.glob("*.wav")):
                filename = file_path.name

                # Expected format: {exp_name}_gpt-{gpt_name}_sovits-{sovits_name}.wav
                gpt_model, sovits_model = self._parse_inference_filename(filename, exp_name)

                gpt_path = str(exp_dir / gpt_weight_dir / f"{gpt_model}.ckpt")
                sovits_path = str(exp_dir / sovits_weight_dir / f"{sovits_model}.pth")

                # Report the path relative to the project root when possible;
                # relative_to() raises ValueError if the file lies outside it.
                try:
                    reported_path = str(file_path.relative_to(settings.PROJECT_ROOT))
                except ValueError:
                    reported_path = str(file_path)

                stat = file_path.stat()
                outputs.append(InferenceOutputItem(
                    filename=filename,
                    gpt_model=gpt_model,
                    sovits_model=sovits_model,
                    gpt_path=gpt_path,
                    sovits_path=sovits_path,
                    file_path=reported_path,
                    size_bytes=stat.st_size,
                    # NOTE: st_ctime is the creation time on Windows but the
                    # last metadata change time on Unix.
                    created_at=datetime.fromtimestamp(stat.st_ctime),
                ))

        return InferenceOutputsResponse(
            task_id=task_id,
            exp_name=exp_name,
            ref_text=ref_text,
            ref_audio_path=ref_audio_path,
            target_text=target_text,
            outputs=outputs,
            total=len(outputs),
        )

    async def download_inference_output(
        self,
        task_id: str,
        filename: str
    ) -> Optional[Tuple[bytes, str, str]]:
        """
        Download one inference output file of a task.

        Args:
            task_id: Task ID.
            filename: File name inside the task's inference directory.

        Returns:
            (file content, file name, content_type), or None when the task or
            file does not exist or the name points outside the inference
            directory.
        """
        task = await self.db.get_task(task_id)
        if not task:
            return None

        inference_dir = Path(settings.EXP_ROOT) / task.exp_name / "inference"

        # Reject names that escape the inference directory (e.g. via "..").
        file_path = self._resolve_inside(inference_dir, filename)
        if file_path is None or not file_path.is_file():
            return None

        return file_path.read_bytes(), filename, "audio/wav"

    async def download_file(
        self,
        task_id: str,
        file_type: str,
        filename: str
    ) -> Optional[Tuple[bytes, str, str]]:
        """
        Download a file of the given type for a task.

        Args:
            task_id: Task ID.
            file_type: File type (output/ref_audio/gpt_model/sovits_model).
            filename: File name; for ``ref_audio`` this is the full path.

        Returns:
            (file content, file name, content_type), or None when the task or
            file does not exist, the type is unknown, or the name points
            outside the directory that belongs to the type.
        """
        task = await self.db.get_task(task_id)
        if not task:
            return None

        exp_dir = Path(settings.EXP_ROOT) / task.exp_name
        version = task.config.get("version", "v2")

        if file_type == "ref_audio":
            # NOTE(security): `filename` is used as a full path here, so any
            # file readable by the process can be requested. Callers must not
            # pass untrusted input for this type without validating it first.
            content_type = "audio/wav"
            try:
                file_path = Path(filename).resolve()
            except (ValueError, OSError, RuntimeError):
                return None
        else:
            if file_type == "output":
                base_dir = exp_dir / "inference"
                content_type = "audio/wav"
            elif file_type == "gpt_model":
                base_dir = exp_dir / self._get_gpt_weight_dir(version)
                content_type = "application/octet-stream"
            elif file_type == "sovits_model":
                base_dir = exp_dir / self._get_sovits_weight_dir(version)
                content_type = "application/octet-stream"
            else:
                return None

            # Reject names that escape the type's directory (e.g. via "..").
            file_path = self._resolve_inside(base_dir, filename)
            if file_path is None:
                return None

        if not file_path.is_file():
            return None

        # The download name is the bare file name, without directories.
        return file_path.read_bytes(), file_path.name, content_type

    @staticmethod
    def _is_within(path: Path, root: Path) -> bool:
        """Return True when `path` equals `root` or lies below it (component-wise)."""
        try:
            path.relative_to(root)
        except ValueError:
            return False
        return True

    @staticmethod
    def _resolve_inside(base_dir: Path, filename: str) -> Optional[Path]:
        """
        Resolve ``base_dir / filename`` and ensure it stays inside ``base_dir``.

        Returns:
            The resolved path, or None when resolution fails or the result
            lies outside the resolved ``base_dir``.
        """
        try:
            root = base_dir.resolve()
            candidate = (base_dir / filename).resolve()
        except (ValueError, OSError, RuntimeError):
            return None
        if not TaskService._is_within(candidate, root):
            return None
        return candidate

    def _parse_inference_filename(self, filename: str, exp_name: str) -> Tuple[str, str]:
        """
        Extract the GPT and SoVITS model names from an inference output name.

        Expected format: {exp_name}_gpt-{gpt_name}_sovits-{sovits_name}.wav

        Args:
            filename: File name.
            exp_name: Experiment name.

        Returns:
            (gpt_model, sovits_model); a part that cannot be found is
            reported as "unknown".
        """
        # Drop the extension, if any.
        name = filename.rsplit(".", 1)[0] if "." in filename else filename

        # Strict match against the full expected format.
        pattern = rf"^{re.escape(exp_name)}_gpt-(.+)_sovits-(.+)$"
        match = re.match(pattern, name)
        if match:
            return match.group(1), match.group(2)

        # Fallback: pick up the gpt- and sovits- parts independently.
        gpt_match = re.search(r"gpt-([^_]+)", name)
        sovits_match = re.search(r"sovits-([^_]+)", name)
        gpt_model = gpt_match.group(1) if gpt_match else "unknown"
        sovits_model = sovits_match.group(1) if sovits_match else "unknown"
        return gpt_model, sovits_model

    @staticmethod
    def _parse_list_line(line: str) -> Tuple[str, str]:
        """
        Parse one ``{audio path}|{folder name}|{language}|{recognized text}`` line.

        Returns:
            (audio path, recognized text), or ("", "") when the line has fewer
            than four fields.
        """
        # maxsplit=3 keeps any "|" inside the recognized text intact.
        parts = line.strip().split('|', 3)
        if len(parts) < 4:
            return "", ""
        return parts[0], parts[3]

    def _parse_list_file(self, exp_name: str) -> Tuple[str, str]:
        """
        Read ref_audio_path and ref_text from the first line of
        asr_opt/slicer_opt.list.

        Args:
            exp_name: Experiment name.

        Returns:
            (ref_audio_path, ref_text); empty strings when the file is missing
            or its first line cannot be parsed.
        """
        list_path = Path(settings.EXP_ROOT) / exp_name / 'asr_opt' / 'slicer_opt.list'
        if not list_path.exists():
            return "", ""

        with open(list_path, 'r', encoding='utf-8') as f:
            first_line = f.readline()
        return self._parse_list_line(first_line)

    def _get_gpt_weight_dir(self, version: str) -> str:
        """Return the GPT weight directory name for a model version (v2 name if unknown)."""
        return self._GPT_WEIGHT_DIRS.get(version, "GPT_weights_v2")

    def _get_sovits_weight_dir(self, version: str) -> str:
        """Return the SoVITS weight directory name for a model version (v2 name if unknown)."""
        return self._SOVITS_WEIGHT_DIRS.get(version, "SoVITS_weights_v2")

    def _task_to_response(self, task: "Task") -> "TaskResponse":
        """Convert a Task domain model into a TaskResponse."""
        return TaskResponse(
            id=task.id,
            exp_name=task.exp_name,
            # Status may be a TaskStatus member or already a plain value.
            status=task.status.value if isinstance(task.status, TaskStatus) else task.status,
            current_stage=task.current_stage,
            progress=task.stage_progress,
            overall_progress=task.progress,
            message=task.message,
            error_message=task.error_message,
            created_at=task.created_at,
            started_at=task.started_at,
            completed_at=task.completed_at,
        )