|
|
""" |
|
|
Quick Mode 任务服务 |
|
|
|
|
|
处理一键训练任务的业务逻辑 |
|
|
""" |
|
|
|
|
|
import re |
|
|
import uuid |
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
from typing import AsyncGenerator, Dict, Optional, Any, Tuple |
|
|
|
|
|
from project_config import settings |
|
|
from ..core.adapters import get_database_adapter, get_task_queue_adapter, get_storage_adapter |
|
|
from ..models.domain import Task, TaskStatus |
|
|
from ..models.schemas.task import ( |
|
|
QuickModeRequest, |
|
|
TaskResponse, |
|
|
TaskListResponse, |
|
|
InferenceOutputItem, |
|
|
InferenceOutputsResponse, |
|
|
) |
|
|
|
|
|
|
|
|
# Training presets keyed by the quality level requested in Quick Mode.
# Each preset fixes the SoVITS and GPT epoch counts and carries a
# human-readable description (shown as-is, hence kept in Chinese).
QUALITY_PRESETS: Dict[str, Dict[str, Any]] = {
    # Fewest epochs: shortest turnaround.
    "fast": {
        "sovits_epochs": 4,
        "gpt_epochs": 8,
        "description": "快速训练,约10分钟",
    },
    # Used as the fallback when an unknown quality level is requested.
    "standard": {
        "sovits_epochs": 8,
        "gpt_epochs": 15,
        "description": "标准训练,约20分钟",
    },
    # Most epochs: longest training time.
    "high": {
        "sovits_epochs": 16,
        "gpt_epochs": 30,
        "description": "高质量训练,约40分钟",
    },
}
|
|
|
|
|
|
|
|
# Default synthesis text per language code, used for the inference stage
# when the request carries no inference options. "zh" doubles as the
# fallback for language codes that are not listed here.
DEFAULT_TARGET_TEXTS: Dict[str, str] = {
    # Mandarin Chinese
    "zh": "这是一段测试语音合成的文本。请你用自然、清晰、不过度夸张的语气朗读,并在逗号和句号处做适当停顿:先慢一点,再稍微快一点,最后恢复正常语速。",
    # English
    "en": "This is a test text for speech synthesis. Please read it naturally and clearly, without exaggeration, pausing appropriately at commas and periods: start slowly, then speed up a bit, and finally return to normal pace.",
    # Japanese
    "ja": "これは音声合成のテストテキストです。自然で明瞭に、大げさにならないように朗読してください。読点と句点で適切に間を置いて:最初はゆっくり、少し速く、最後は普通の速さに戻してください。",
    # Korean
    "ko": "이것은 음성 합성을 위한 테스트 텍스트입니다. 자연스럽고 명확하게, 과장하지 않고 읽어주세요. 쉼표와 마침표에서 적절히 멈추며: 먼저 천천히, 그 다음 조금 빠르게, 마지막으로 보통 속도로 돌아오세요.",
    # Cantonese
    "yue": "呢段係測試語音合成嘅文字。請你用自然、清楚、唔好太誇張嘅語氣讀出嚟,喺逗號同句號嗰度要適當噉停一停:開頭慢啲,跟住快少少,最後返返正常語速。",
}
|
|
|
|
|
|
|
|
class TaskService:
    """Quick Mode task service.

    Manages the life cycle of one-click training tasks:

    - creating tasks
    - querying task status
    - cancelling tasks
    - subscribing to progress updates
    - listing and downloading inference outputs and model weights

    The database, task-queue and storage adapters are fetched lazily on
    first access, so constructing the service does not touch them.

    Example:
        >>> service = TaskService()
        >>> task = await service.create_quick_task(request)
        >>> status = await service.get_task(task.id)
        >>> await service.cancel_task(task.id)
    """

    # Training pipeline stages in execution order; "inference" is appended
    # per task when inference is enabled.
    _TRAINING_STAGES = (
        "audio_slice",
        "asr",
        "text_feature",
        "hubert_feature",
        "semantic_token",
        "sovits_train",
        "gpt_train",
    )

    # Progress "status" values after which the progress stream is closed.
    _TERMINAL_PROGRESS_STATUSES = ("completed", "failed", "cancelled")

    # Model version -> weight directory name (relative to the experiment dir).
    _GPT_WEIGHT_DIRS = {
        "v1": "GPT_weights",
        "v2": "GPT_weights_v2",
        "v3": "GPT_weights_v3",
        "v4": "GPT_weights_v4",
        "v2Pro": "GPT_weights_v2Pro",
        "v2ProPlus": "GPT_weights_v2ProPlus",
    }
    _SOVITS_WEIGHT_DIRS = {
        "v1": "SoVITS_weights",
        "v2": "SoVITS_weights_v2",
        "v3": "SoVITS_weights_v3",
        "v4": "SoVITS_weights_v4",
        "v2Pro": "SoVITS_weights_v2Pro",
        "v2ProPlus": "SoVITS_weights_v2ProPlus",
    }

    # Fallback patterns for inference file names that do not start with the
    # expected "{exp_name}_" prefix.
    _GPT_NAME_RE = re.compile(r"gpt-([^_]+)")
    _SOVITS_NAME_RE = re.compile(r"sovits-([^_]+)")

    def __init__(self):
        """Initialise the service with no adapters resolved yet."""
        self._db = None
        self._queue = None
        self._storage = None

    @property
    def db(self):
        """Database adapter, fetched on first access and then cached."""
        if self._db is None:
            self._db = get_database_adapter()
        return self._db

    @property
    def queue(self):
        """Task-queue adapter, fetched on first access and then cached."""
        if self._queue is None:
            self._queue = get_task_queue_adapter()
        return self._queue

    @property
    def storage(self):
        """Storage adapter, fetched on first access and then cached."""
        if self._storage is None:
            self._storage = get_storage_adapter()
        return self._storage

    async def check_exp_name_exists(self, exp_name: str) -> bool:
        """Check whether a task with the given experiment name exists.

        Args:
            exp_name: Experiment name to look up.

        Returns:
            True if the database returns a task for ``exp_name``, else False.
        """
        existing_task = await self.db.get_task_by_exp_name(exp_name)
        return existing_task is not None

    async def validate_audio_file(self, audio_file_id: str) -> Tuple[bool, str]:
        """Resolve an audio file ID (or path) and check that it exists.

        If the storage adapter returns metadata for ``audio_file_id``, the
        file is looked up under the storage base path; otherwise the value
        is treated as a filesystem path as given.

        Args:
            audio_file_id: Audio file ID or filesystem path.

        Returns:
            ``(exists, resolved_path)``.
        """
        # Local import: this module does not import ``os`` at file level.
        import os

        file_metadata = await self.storage.get_file_metadata(audio_file_id)
        if file_metadata:
            audio_file_path = str(self.storage.base_path / audio_file_id)
        else:
            audio_file_path = audio_file_id
        return os.path.exists(audio_file_path), audio_file_path

    async def create_quick_task(self, request: QuickModeRequest) -> TaskResponse:
        """Create and enqueue a one-click training task.

        Training parameters are derived from the request options and the
        matching quality preset (falling back to ``"standard"``).

        Args:
            request: Quick Mode request.

        Returns:
            The created task as a ``TaskResponse``.

        Raises:
            Exception: Any error raised while enqueueing is re-raised after
                the stored task has been marked as failed.
        """
        task_id = f"task-{uuid.uuid4().hex[:12]}"

        # Only the resolved path is used; the existence flag is not enforced.
        # NOTE(review): callers presumably validate the file beforehand - confirm.
        _, audio_file_path = await self.validate_audio_file(request.audio_file_id)

        config = self._build_quick_config(request, audio_file_path)

        task = Task(
            id=task_id,
            exp_name=request.exp_name,
            config=config,
            status=TaskStatus.QUEUED,
            created_at=datetime.utcnow(),
        )
        await self.db.create_task(task)

        try:
            job_id = await self.queue.enqueue(task_id, config)
        except Exception as exc:
            # Without this the record would stay QUEUED although no job
            # was ever submitted for it.
            await self.db.update_task(task_id, {
                "status": TaskStatus.FAILED,
                "completed_at": datetime.utcnow(),
                "message": f"任务入队失败: {exc}",
            })
            raise

        await self.db.update_task(task_id, {"job_id": job_id})
        task.job_id = job_id

        return self._task_to_response(task)

    def _build_quick_config(
        self,
        request: QuickModeRequest,
        audio_file_path: str
    ) -> Dict[str, Any]:
        """Build the task configuration dict for a Quick Mode request."""
        options = request.options
        quality = options.quality
        preset = QUALITY_PRESETS.get(quality, QUALITY_PRESETS["standard"])

        # Inference runs unless it is explicitly disabled in the request.
        inference_opts = options.inference
        inference_enabled = inference_opts is None or inference_opts.enabled

        stages = list(self._TRAINING_STAGES)
        if inference_enabled:
            stages.append("inference")

        config = {
            "exp_name": request.exp_name,
            "audio_file_id": request.audio_file_id,
            "input_path": audio_file_path,
            "version": options.version,
            "language": options.language,
            "quality": quality,

            "total_epoch": preset["sovits_epochs"],
            "sovits_epochs": preset["sovits_epochs"],
            "gpt_epochs": preset["gpt_epochs"],

            "bert_pretrained_dir": str(settings.BERT_PRETRAINED_DIR),
            "ssl_pretrained_dir": str(settings.SSL_PRETRAINED_DIR),
            "pretrained_s2G": str(settings.PRETRAINED_S2G),
            "pretrained_s2D": str(settings.PRETRAINED_S2D),
            "pretrained_s1": str(settings.PRETRAINED_S1),

            "stages": stages,
        }

        if inference_enabled:
            default_text = DEFAULT_TARGET_TEXTS.get(
                options.language, DEFAULT_TARGET_TEXTS["zh"]
            )
            if inference_opts:
                config["ref_text"] = inference_opts.ref_text or ""
                config["ref_audio_path"] = inference_opts.ref_audio_path or ""
                # Fall back to the language default rather than storing an
                # empty/None synthesis text.
                config["target_text"] = inference_opts.target_text or default_text
            else:
                # Empty reference values are filled in later from the list
                # file (see get_inference_outputs).
                config["ref_text"] = ""
                config["ref_audio_path"] = ""
                config["target_text"] = default_text

        return config

    async def get_task(self, task_id: str) -> Optional[TaskResponse]:
        """Fetch a single task.

        Args:
            task_id: Task ID.

        Returns:
            The task as a ``TaskResponse``, or None if it does not exist.
        """
        task = await self.db.get_task(task_id)
        if not task:
            return None
        return self._task_to_response(task)

    async def list_tasks(
        self,
        status: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> TaskListResponse:
        """List tasks with optional status filter and pagination.

        Args:
            status: Only return tasks with this status (no filter if None).
            limit: Page size.
            offset: Number of tasks to skip.

        Returns:
            ``TaskListResponse`` holding the page and the total count for
            the same status filter.
        """
        tasks = await self.db.list_tasks(status=status, limit=limit, offset=offset)
        total = await self.db.count_tasks(status=status)

        return TaskListResponse(
            items=[self._task_to_response(t) for t in tasks],
            total=total,
            limit=limit,
            offset=offset,
        )

    async def cancel_task(self, task_id: str) -> bool:
        """Cancel a queued or running task.

        Args:
            task_id: Task ID.

        Returns:
            True if the task was cancelled; False if it does not exist or
            is not in the QUEUED/RUNNING state.
        """
        task = await self.db.get_task(task_id)
        if not task:
            return False

        if task.status not in (TaskStatus.QUEUED, TaskStatus.RUNNING):
            return False

        # A task may have no job_id if it was never enqueued successfully.
        if task.job_id:
            await self.queue.cancel(task.job_id)

        await self.db.update_task(task_id, {
            "status": TaskStatus.CANCELLED,
            "completed_at": datetime.utcnow(),
            "message": "任务已取消",
        })

        return True

    async def subscribe_progress(
        self,
        task_id: str
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Stream progress updates for a task (SSE source).

        For an unknown task a single ``error`` event is yielded; for a task
        that is already completed, failed or cancelled a single ``final``
        event is yielded. Otherwise queue progress events are forwarded
        until one reports a terminal status.

        Args:
            task_id: Task ID.

        Yields:
            Progress dictionaries.
        """
        task = await self.db.get_task(task_id)
        if not task:
            yield {"type": "error", "message": "任务不存在"}
            return

        if task.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
            yield {
                "type": "final",
                "status": task.status.value,
                "message": task.message or task.error_message,
                "progress": task.progress,
            }
            return

        async for progress in self.queue.subscribe_progress(task_id):
            yield progress

            # The terminal event itself is forwarded before the stream ends.
            if progress.get("status") in self._TERMINAL_PROGRESS_STATUSES:
                break

    async def get_inference_outputs(self, task_id: str) -> Optional[InferenceOutputsResponse]:
        """List the inference outputs produced for a task.

        Scans ``{EXP_ROOT}/{exp_name}/inference/`` for ``*.wav`` files and
        returns their metadata, sorted by path for a stable order.

        Args:
            task_id: Task ID.

        Returns:
            ``InferenceOutputsResponse``, or None if the task does not exist.
        """
        task = await self.db.get_task(task_id)
        if not task:
            return None

        exp_name = task.exp_name
        exp_dir = Path(settings.EXP_ROOT) / exp_name
        inference_dir = exp_dir / "inference"

        config = task.config or {}
        ref_text = config.get("ref_text", "")
        ref_audio_path = config.get("ref_audio_path", "")
        target_text = config.get("target_text", "")

        # Reference audio and text are treated as a pair: if either one is
        # missing, both are taken from the first list-file entry.
        if not ref_text or not ref_audio_path:
            ref_audio_path, ref_text = self._parse_list_file(exp_name)

        version = config.get("version", "v2")
        gpt_weight_dir = exp_dir / self._get_gpt_weight_dir(version)
        sovits_weight_dir = exp_dir / self._get_sovits_weight_dir(version)

        outputs = []
        if inference_dir.is_dir():
            for file_path in sorted(inference_dir.glob("*.wav")):
                try:
                    stat = file_path.stat()
                except OSError:
                    # File vanished (or became unreadable) since the scan.
                    continue

                filename = file_path.name
                gpt_model, sovits_model = self._parse_inference_filename(filename, exp_name)

                outputs.append(InferenceOutputItem(
                    filename=filename,
                    gpt_model=gpt_model,
                    sovits_model=sovits_model,
                    gpt_path=str(gpt_weight_dir / f"{gpt_model}.ckpt"),
                    sovits_path=str(sovits_weight_dir / f"{sovits_model}.pth"),
                    file_path=self._relative_to_project(file_path),
                    size_bytes=stat.st_size,
                    created_at=datetime.fromtimestamp(stat.st_ctime),
                ))

        return InferenceOutputsResponse(
            task_id=task_id,
            exp_name=exp_name,
            ref_text=ref_text,
            ref_audio_path=ref_audio_path,
            target_text=target_text,
            outputs=outputs,
            total=len(outputs),
        )

    async def download_inference_output(
        self,
        task_id: str,
        filename: str
    ) -> Optional[Tuple[bytes, str, str]]:
        """Read one inference output file of a task.

        Args:
            task_id: Task ID.
            filename: File name inside the task's inference directory.

        Returns:
            ``(content, filename, content_type)``, or None if the task or
            file does not exist or ``filename`` points outside the
            inference directory.
        """
        task = await self.db.get_task(task_id)
        if not task:
            return None

        inference_dir = Path(settings.EXP_ROOT) / task.exp_name / "inference"
        file_path = self._resolve_within(inference_dir, filename)
        if file_path is None:
            return None

        return file_path.read_bytes(), filename, "audio/wav"

    async def download_file(
        self,
        task_id: str,
        file_type: str,
        filename: str
    ) -> Optional[Tuple[bytes, str, str]]:
        """Read a file belonging to a task, selected by type.

        Args:
            task_id: Task ID.
            file_type: One of ``output``, ``ref_audio``, ``gpt_model``,
                ``sovits_model``.
            filename: File name inside the directory for ``file_type``; for
                ``ref_audio`` it is used as the file path itself.

        Returns:
            ``(content, download_filename, content_type)``, or None if the
            task or file does not exist, ``file_type`` is unknown, or
            ``filename`` points outside the directory for ``file_type``.
        """
        task = await self.db.get_task(task_id)
        if not task:
            return None

        exp_dir = Path(settings.EXP_ROOT) / task.exp_name
        version = (task.config or {}).get("version", "v2")

        if file_type == "ref_audio":
            # NOTE(review): SECURITY - ``filename`` is used as a path with no
            # containment check, so any file readable by the process can be
            # fetched. Kept for compatibility; consider restricting it to
            # the task's recorded reference audio.
            content_type = "audio/wav"
            try:
                file_path = Path(filename).resolve()
            except (ValueError, OSError):
                return None
            if not file_path.is_file():
                return None
        else:
            if file_type == "output":
                base_dir = exp_dir / "inference"
                content_type = "audio/wav"
            elif file_type == "gpt_model":
                base_dir = exp_dir / self._get_gpt_weight_dir(version)
                content_type = "application/octet-stream"
            elif file_type == "sovits_model":
                base_dir = exp_dir / self._get_sovits_weight_dir(version)
                content_type = "application/octet-stream"
            else:
                return None

            # Reject names such as "../../x" that leave the expected directory.
            file_path = self._resolve_within(base_dir, filename)
            if file_path is None:
                return None

        return file_path.read_bytes(), file_path.name, content_type

    @staticmethod
    def _resolve_within(base_dir: Path, filename: str) -> Optional[Path]:
        """Resolve ``filename`` under ``base_dir``; None if it escapes or is not a file."""
        try:
            root = Path(base_dir).resolve()
            candidate = (root / filename).resolve()
            # Component-wise containment check. A plain string prefix test
            # would wrongly accept sibling directories such as
            # "inference_evil" for the base "inference".
            candidate.relative_to(root)
        except (ValueError, OSError):
            return None
        if not candidate.is_file():
            return None
        return candidate

    @staticmethod
    def _relative_to_project(path: Path) -> str:
        """Return ``path`` relative to PROJECT_ROOT, or unchanged if it lies outside."""
        try:
            return str(path.relative_to(settings.PROJECT_ROOT))
        except ValueError:
            return str(path)

    def _parse_inference_filename(self, filename: str, exp_name: str) -> Tuple[str, str]:
        """Extract the GPT and SoVITS model names from an output file name.

        Expected format: ``{exp_name}_gpt-{gpt_name}_sovits-{sovits_name}.wav``.
        If the name does not match, each model name is searched for
        separately (up to the next underscore) and defaults to
        ``"unknown"``.

        Args:
            filename: Output file name.
            exp_name: Experiment name.

        Returns:
            ``(gpt_model, sovits_model)``.
        """
        # Drop the last extension only; model names may contain dots.
        name = filename.rsplit(".", 1)[0] if "." in filename else filename

        pattern = rf"^{re.escape(exp_name)}_gpt-(.+)_sovits-(.+)$"
        match = re.match(pattern, name)
        if match:
            return match.group(1), match.group(2)

        gpt_match = self._GPT_NAME_RE.search(name)
        sovits_match = self._SOVITS_NAME_RE.search(name)

        gpt_model = gpt_match.group(1) if gpt_match else "unknown"
        sovits_model = sovits_match.group(1) if sovits_match else "unknown"

        return gpt_model, sovits_model

    def _parse_list_file(self, exp_name: str) -> Tuple[str, str]:
        """Read the reference audio path and text from the first list-file line.

        The file read is ``{EXP_ROOT}/{exp_name}/asr_opt/slicer_opt.list``.

        Args:
            exp_name: Experiment name.

        Returns:
            ``(ref_audio_path, ref_text)``; two empty strings if the file is
            missing, unreadable, or its first line cannot be parsed.
        """
        list_path = Path(settings.EXP_ROOT) / exp_name / 'asr_opt' / 'slicer_opt.list'
        try:
            with open(list_path, 'r', encoding='utf-8') as f:
                first_line = f.readline()
        except (OSError, UnicodeDecodeError):
            return "", ""
        return self._parse_list_line(first_line)

    @staticmethod
    def _parse_list_line(line: str) -> Tuple[str, str]:
        """Return (first field, fourth field) of a '|'-separated line, or ("", "")."""
        # maxsplit=3 keeps any further '|' characters inside the text field.
        parts = line.strip().split('|', 3)
        if len(parts) < 4:
            return "", ""
        return parts[0], parts[3]

    def _get_gpt_weight_dir(self, version: str) -> str:
        """Return the GPT weight directory name for a model version (default: v2)."""
        return self._GPT_WEIGHT_DIRS.get(version, "GPT_weights_v2")

    def _get_sovits_weight_dir(self, version: str) -> str:
        """Return the SoVITS weight directory name for a model version (default: v2)."""
        return self._SOVITS_WEIGHT_DIRS.get(version, "SoVITS_weights_v2")

    def _task_to_response(self, task: Task) -> TaskResponse:
        """Convert a ``Task`` domain object into a ``TaskResponse``."""
        return TaskResponse(
            id=task.id,
            exp_name=task.exp_name,
            # status may be a TaskStatus member or an already-plain value.
            status=task.status.value if isinstance(task.status, TaskStatus) else task.status,
            current_stage=task.current_stage,
            # Response "progress" is the stage-level progress; the task's
            # own "progress" is exposed as the overall progress.
            progress=task.stage_progress,
            overall_progress=task.progress,
            message=task.message,
            error_message=task.error_message,
            created_at=task.created_at,
            started_at=task.started_at,
            completed_at=task.completed_at,
        )
|
|
|