liumaolin
refactor(config): centralize configuration management in `project_config`
8f68d0a
"""
SQLite 数据库适配器
基于 SQLite + aiosqlite 实现的数据库适配器,适用于 macOS 本地训练场景。
"""
import json
import sqlite3
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
import aiosqlite
from project_config import settings
from ..base import DatabaseAdapter
from ...models.domain import Task, TaskStatus
# Known stage types. SQLiteAdapter.create_experiment() creates one
# "pending" row in the ``stages`` table for every entry, in this order.
STAGE_TYPES: List[str] = [
    "audio_slice",
    "asr",
    "text_feature",
    "hubert_feature",
    "semantic_token",
    "sovits_train",
    "gpt_train",
]
class SQLiteAdapter(DatabaseAdapter):
    """
    SQLite database adapter.

    Features:
    1. Asynchronous database access through aiosqlite.
    2. Manages Tasks (Quick Mode) and Experiments (Advanced Mode).
    3. Creates the table schema automatically when constructed.

    Tables:
    - tasks: Quick Mode tasks
    - experiments: Advanced Mode experiments
    - stages: per-experiment stage status
    - files: file records

    Example:
        >>> adapter = SQLiteAdapter()
        >>> task = Task(id="task-123", exp_name="my_voice", config={})
        >>> await adapter.create_task(task)
        >>> task = await adapter.get_task("task-123")
    """

    # Column whitelists for the dynamic ``UPDATE ... SET`` builders. The keys
    # of an ``updates`` dict end up in SQL as identifiers, which cannot be
    # bound as parameters, so they are validated against the schema first.
    _TASK_COLUMNS = frozenset({
        "id", "job_id", "exp_name", "status", "config", "current_stage",
        "progress", "stage_progress", "message", "error_message",
        "created_at", "started_at", "completed_at",
    })
    _EXPERIMENT_COLUMNS = frozenset({
        "id", "exp_name", "version", "exp_root", "gpu_numbers", "is_half",
        "audio_file_id", "status", "created_at", "updated_at",
    })
    _STAGE_COLUMNS = frozenset({
        "id", "experiment_id", "stage_type", "status", "progress", "message",
        "job_id", "config", "outputs", "started_at", "completed_at",
        "error_message",
    })
    # Stage columns that hold JSON-encoded text.
    _STAGE_JSON_FIELDS = ("config", "outputs")

    def __init__(self, db_path: Optional[str] = None):
        """
        Initialize the adapter and make sure the schema exists.

        Args:
            db_path: Path of the SQLite database file. Falls back to
                ``settings.SQLITE_PATH`` when not given (or empty).
        """
        self.db_path = db_path if db_path else str(settings.SQLITE_PATH)
        # Make sure the parent directory exists before sqlite creates the file.
        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
        # Create the schema synchronously so the adapter is usable immediately.
        self._init_db_sync()

    def _init_db_sync(self) -> None:
        """Create all tables and indexes if they do not exist yet (blocking)."""
        conn = sqlite3.connect(self.db_path)
        try:
            # ``with conn`` only commits (or rolls back on error); it does NOT
            # close the connection, hence the explicit close in ``finally``.
            with conn:
                # Tasks table (Quick Mode)
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS tasks (
                        id TEXT PRIMARY KEY,
                        job_id TEXT,
                        exp_name TEXT NOT NULL,
                        status TEXT NOT NULL DEFAULT 'queued',
                        config TEXT,
                        current_stage TEXT,
                        progress REAL DEFAULT 0,
                        stage_progress REAL DEFAULT 0,
                        message TEXT,
                        error_message TEXT,
                        created_at TEXT NOT NULL,
                        started_at TEXT,
                        completed_at TEXT
                    )
                ''')
                conn.execute('CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status)')
                conn.execute('CREATE INDEX IF NOT EXISTS idx_tasks_created ON tasks(created_at)')

                # Experiments table (Advanced Mode)
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS experiments (
                        id TEXT PRIMARY KEY,
                        exp_name TEXT NOT NULL,
                        version TEXT NOT NULL DEFAULT 'v2',
                        exp_root TEXT DEFAULT 'logs',
                        gpu_numbers TEXT DEFAULT '0',
                        is_half INTEGER DEFAULT 1,
                        audio_file_id TEXT,
                        status TEXT NOT NULL DEFAULT 'created',
                        created_at TEXT NOT NULL,
                        updated_at TEXT
                    )
                ''')
                conn.execute('CREATE INDEX IF NOT EXISTS idx_experiments_status ON experiments(status)')
                conn.execute('CREATE INDEX IF NOT EXISTS idx_experiments_created ON experiments(created_at)')

                # Stages table (Advanced Mode stage status)
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS stages (
                        id TEXT PRIMARY KEY,
                        experiment_id TEXT NOT NULL,
                        stage_type TEXT NOT NULL,
                        status TEXT DEFAULT 'pending',
                        progress REAL DEFAULT 0,
                        message TEXT,
                        job_id TEXT,
                        config TEXT,
                        outputs TEXT,
                        started_at TEXT,
                        completed_at TEXT,
                        error_message TEXT,
                        FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE,
                        UNIQUE (experiment_id, stage_type)
                    )
                ''')
                conn.execute('CREATE INDEX IF NOT EXISTS idx_stages_experiment ON stages(experiment_id)')
                conn.execute('CREATE INDEX IF NOT EXISTS idx_stages_status ON stages(status)')

                # Files table (file records)
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS files (
                        id TEXT PRIMARY KEY,
                        filename TEXT NOT NULL,
                        content_type TEXT,
                        size_bytes INTEGER DEFAULT 0,
                        purpose TEXT DEFAULT 'training',
                        duration_seconds REAL,
                        sample_rate INTEGER,
                        storage_path TEXT,
                        uploaded_at TEXT NOT NULL
                    )
                ''')
                conn.execute('CREATE INDEX IF NOT EXISTS idx_files_purpose ON files(purpose)')
                conn.execute('CREATE INDEX IF NOT EXISTS idx_files_uploaded ON files(uploaded_at)')
        finally:
            conn.close()

    # ============================================================
    # Shared helpers
    # ============================================================

    @staticmethod
    def _build_set_clause(allowed: frozenset, columns: Dict[str, Any], table: str) -> str:
        """
        Build the ``"col_a = ?, col_b = ?"`` part of an UPDATE statement.

        Raises:
            ValueError: If any key of ``columns`` is not in ``allowed``.
        """
        unknown = sorted(set(columns) - allowed)
        if unknown:
            raise ValueError(
                f"Unknown column(s) for table '{table}': {', '.join(unknown)}"
            )
        return ", ".join(f"{name} = ?" for name in columns)

    @classmethod
    def _decode_stage_row(cls, row: Any) -> Dict[str, Any]:
        """
        Convert a ``stages`` row to a dict and decode its JSON text columns.

        Columns that hold invalid JSON are replaced by ``None``.
        """
        stage = dict(row)
        for field in cls._STAGE_JSON_FIELDS:
            raw = stage.get(field)
            if raw and isinstance(raw, str):
                try:
                    stage[field] = json.loads(raw)
                except json.JSONDecodeError:
                    stage[field] = None
        return stage

    # ============================================================
    # Task CRUD (Quick Mode)
    # ============================================================

    async def create_task(self, task: Task) -> Task:
        """
        Insert a new task row.

        ``config`` is stored as JSON (``NULL`` when empty). If ``task.created_at``
        is unset, the row gets the current naive UTC time, but the returned
        ``task`` object itself is not modified.

        Returns:
            The same ``task`` object that was passed in.
        """
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                '''INSERT INTO tasks
                   (id, job_id, exp_name, status, config, current_stage,
                    progress, stage_progress, message, error_message,
                    created_at, started_at, completed_at)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)''',
                (
                    task.id,
                    task.job_id,
                    task.exp_name,
                    task.status.value if isinstance(task.status, TaskStatus) else task.status,
                    json.dumps(task.config, ensure_ascii=False) if task.config else None,
                    task.current_stage,
                    task.progress,
                    task.stage_progress,
                    task.message,
                    task.error_message,
                    task.created_at.isoformat() if task.created_at else datetime.utcnow().isoformat(),
                    task.started_at.isoformat() if task.started_at else None,
                    task.completed_at.isoformat() if task.completed_at else None,
                )
            )
            await db.commit()
        return task

    async def get_task(self, task_id: str) -> Optional[Task]:
        """Return the task with ``task_id``, or ``None`` if it does not exist."""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                "SELECT * FROM tasks WHERE id = ?", (task_id,)
            ) as cursor:
                row = await cursor.fetchone()
        if row:
            return self._row_to_task(dict(row))
        return None

    async def update_task(self, task_id: str, updates: Dict[str, Any]) -> Optional[Task]:
        """
        Apply a partial update to a task.

        Args:
            task_id: ID of the task to update.
            updates: Column name -> new value. ``TaskStatus`` values are stored
                by ``.value``, a ``dict`` ``config`` is JSON-encoded and
                ``datetime`` timestamps are stored as ISO-8601 strings.

        Returns:
            The task after the update, or ``None`` if it does not exist.

        Raises:
            ValueError: If ``updates`` names a column the ``tasks`` table lacks.
        """
        if not updates:
            return await self.get_task(task_id)

        processed: Dict[str, Any] = {}
        for key, value in updates.items():
            if key == "status" and isinstance(value, TaskStatus):
                processed[key] = value.value
            elif key == "config" and isinstance(value, dict):
                processed[key] = json.dumps(value, ensure_ascii=False)
            elif key in ("created_at", "started_at", "completed_at") and isinstance(value, datetime):
                processed[key] = value.isoformat()
            else:
                processed[key] = value

        # Validate column names before they are interpolated into SQL.
        set_clause = self._build_set_clause(self._TASK_COLUMNS, processed, "tasks")
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                f"UPDATE tasks SET {set_clause} WHERE id = ?",
                [*processed.values(), task_id],
            )
            await db.commit()
        return await self.get_task(task_id)

    async def list_tasks(
        self,
        status: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> List[Task]:
        """List tasks, newest first, optionally filtered by ``status``."""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            if status:
                query = """
                    SELECT * FROM tasks
                    WHERE status = ?
                    ORDER BY created_at DESC
                    LIMIT ? OFFSET ?
                """
                params = (status, limit, offset)
            else:
                query = """
                    SELECT * FROM tasks
                    ORDER BY created_at DESC
                    LIMIT ? OFFSET ?
                """
                params = (limit, offset)
            async with db.execute(query, params) as cursor:
                rows = await cursor.fetchall()
        return [self._row_to_task(dict(row)) for row in rows]

    async def delete_task(self, task_id: str) -> bool:
        """Delete a task; return ``True`` if a row was removed."""
        async with aiosqlite.connect(self.db_path) as db:
            cursor = await db.execute(
                "DELETE FROM tasks WHERE id = ?", (task_id,)
            )
            await db.commit()
            return cursor.rowcount > 0

    async def count_tasks(self, status: Optional[str] = None) -> int:
        """Count tasks, optionally only those with the given ``status``."""
        async with aiosqlite.connect(self.db_path) as db:
            if status:
                async with db.execute(
                    "SELECT COUNT(*) FROM tasks WHERE status = ?", (status,)
                ) as cursor:
                    row = await cursor.fetchone()
            else:
                async with db.execute("SELECT COUNT(*) FROM tasks") as cursor:
                    row = await cursor.fetchone()
        return row[0] if row else 0

    async def get_task_by_exp_name(self, exp_name: str) -> Optional[Task]:
        """
        Return a task whose ``exp_name`` matches, or ``None``.

        The query has no ORDER BY, so if several tasks share the name, which
        one is returned is not specified.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                "SELECT * FROM tasks WHERE exp_name = ? LIMIT 1", (exp_name,)
            ) as cursor:
                row = await cursor.fetchone()
        if row:
            return self._row_to_task(dict(row))
        return None

    def _row_to_task(self, row: Dict[str, Any]) -> Task:
        """Convert a ``tasks`` row dict into a ``Task`` (invalid config JSON -> {})."""
        config = row.get("config")
        if config and isinstance(config, str):
            try:
                config = json.loads(config)
            except json.JSONDecodeError:
                config = {}

        return Task.from_dict({
            "id": row["id"],
            "job_id": row.get("job_id"),
            "exp_name": row["exp_name"],
            "status": row.get("status", "queued"),
            "config": config or {},
            "current_stage": row.get("current_stage"),
            "progress": row.get("progress", 0.0),
            "stage_progress": row.get("stage_progress", 0.0),
            "message": row.get("message"),
            "error_message": row.get("error_message"),
            "created_at": row.get("created_at"),
            "started_at": row.get("started_at"),
            "completed_at": row.get("completed_at"),
        })

    # ============================================================
    # Experiment CRUD (Advanced Mode)
    # ============================================================

    async def create_experiment(self, experiment: Dict[str, Any]) -> Dict[str, Any]:
        """
        Create an experiment together with one ``pending`` row per stage type.

        Args:
            experiment: Must contain ``exp_name``. ``id`` is generated as
                ``exp-<8 hex chars>`` when missing; other columns fall back to
                the defaults used below.

        Returns:
            The stored experiment as returned by :meth:`get_experiment`.
        """
        exp_id = experiment.get("id") or f"exp-{uuid.uuid4().hex[:8]}"
        now = datetime.utcnow().isoformat()

        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                '''INSERT INTO experiments
                   (id, exp_name, version, exp_root, gpu_numbers, is_half,
                    audio_file_id, status, created_at, updated_at)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)''',
                (
                    exp_id,
                    experiment["exp_name"],
                    experiment.get("version", "v2"),
                    experiment.get("exp_root", "logs"),
                    experiment.get("gpu_numbers", "0"),
                    1 if experiment.get("is_half", True) else 0,
                    experiment.get("audio_file_id"),
                    experiment.get("status", "created"),
                    now,
                    now,
                )
            )

            # Seed the initial state of every stage in the same transaction.
            for stage_type in STAGE_TYPES:
                await db.execute(
                    '''INSERT INTO stages
                       (id, experiment_id, stage_type, status)
                       VALUES (?, ?, ?, 'pending')''',
                    (f"{exp_id}-{stage_type}", exp_id, stage_type)
                )

            await db.commit()

        return await self.get_experiment(exp_id)

    async def get_experiment(self, exp_id: str) -> Optional[Dict[str, Any]]:
        """
        Return an experiment with its stages, or ``None`` if it does not exist.

        ``is_half`` is converted to ``bool`` and ``stages`` maps each
        ``stage_type`` to its row dict (JSON columns decoded).
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row

            async with db.execute(
                "SELECT * FROM experiments WHERE id = ?", (exp_id,)
            ) as cursor:
                row = await cursor.fetchone()
            if not row:
                return None

            experiment = dict(row)
            experiment["is_half"] = bool(experiment.get("is_half", 1))

            async with db.execute(
                "SELECT * FROM stages WHERE experiment_id = ?", (exp_id,)
            ) as cursor:
                stage_rows = await cursor.fetchall()

        stages = {}
        for stage_row in stage_rows:
            stage = self._decode_stage_row(stage_row)
            stages[stage["stage_type"]] = stage
        experiment["stages"] = stages
        return experiment

    async def update_experiment(
        self,
        exp_id: str,
        updates: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        Apply a partial update to an experiment and refresh ``updated_at``.

        Args:
            exp_id: ID of the experiment to update.
            updates: Column name -> new value. ``is_half`` is stored as 0/1,
                ``datetime`` timestamps as ISO-8601 strings, and a ``stages``
                key is ignored (stages live in their own table; see
                :meth:`update_stage`). ``updated_at`` defaults to now (UTC).

        Returns:
            The updated experiment (see :meth:`get_experiment`), or ``None``.

        Raises:
            ValueError: If ``updates`` names a column the ``experiments``
                table lacks.
        """
        if not updates:
            return await self.get_experiment(exp_id)

        processed: Dict[str, Any] = {}
        for key, value in updates.items():
            if key == "stages":
                continue
            if key == "is_half":
                processed[key] = 1 if value else 0
            elif key in ("created_at", "updated_at") and isinstance(value, datetime):
                processed[key] = value.isoformat()
            else:
                processed[key] = value
        processed.setdefault("updated_at", datetime.utcnow().isoformat())

        # Validate column names before they are interpolated into SQL.
        set_clause = self._build_set_clause(self._EXPERIMENT_COLUMNS, processed, "experiments")
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                f"UPDATE experiments SET {set_clause} WHERE id = ?",
                [*processed.values(), exp_id],
            )
            await db.commit()
        return await self.get_experiment(exp_id)

    async def list_experiments(
        self,
        status: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> List[Dict[str, Any]]:
        """List experiments, newest first, without their ``stages``."""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            if status:
                query = """
                    SELECT * FROM experiments
                    WHERE status = ?
                    ORDER BY created_at DESC
                    LIMIT ? OFFSET ?
                """
                params = (status, limit, offset)
            else:
                query = """
                    SELECT * FROM experiments
                    ORDER BY created_at DESC
                    LIMIT ? OFFSET ?
                """
                params = (limit, offset)
            async with db.execute(query, params) as cursor:
                rows = await cursor.fetchall()

        results = []
        for row in rows:
            exp = dict(row)
            exp["is_half"] = bool(exp.get("is_half", 1))
            results.append(exp)
        return results

    async def delete_experiment(self, exp_id: str) -> bool:
        """Delete an experiment and its stages; return ``True`` if the experiment existed."""
        async with aiosqlite.connect(self.db_path) as db:
            # SQLite only enforces the declared ON DELETE CASCADE when
            # ``PRAGMA foreign_keys = ON`` is set on the connection, which this
            # adapter never does, so the stages are removed explicitly.
            await db.execute(
                "DELETE FROM stages WHERE experiment_id = ?", (exp_id,)
            )
            cursor = await db.execute(
                "DELETE FROM experiments WHERE id = ?", (exp_id,)
            )
            await db.commit()
            return cursor.rowcount > 0

    # ============================================================
    # Stage operations (Advanced Mode)
    # ============================================================

    async def update_stage(
        self,
        exp_id: str,
        stage_type: str,
        updates: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        Apply a partial update to one stage and touch its experiment.

        Args:
            exp_id: Experiment ID.
            stage_type: Stage type (one of ``STAGE_TYPES``).
            updates: Column name -> new value. ``dict``/``list`` values for
                ``config``/``outputs`` are JSON-encoded; ``datetime`` values for
                ``started_at``/``completed_at`` are stored as ISO-8601 strings.

        Returns:
            The updated stage (see :meth:`get_stage`), or ``None``.

        Raises:
            ValueError: If ``updates`` names a column the ``stages`` table lacks.
        """
        if not updates:
            return await self.get_stage(exp_id, stage_type)

        processed: Dict[str, Any] = {}
        for key, value in updates.items():
            if key in self._STAGE_JSON_FIELDS and isinstance(value, (dict, list)):
                processed[key] = json.dumps(value, ensure_ascii=False)
            elif key in ("started_at", "completed_at") and isinstance(value, datetime):
                processed[key] = value.isoformat()
            else:
                processed[key] = value

        # Validate column names before they are interpolated into SQL.
        set_clause = self._build_set_clause(self._STAGE_COLUMNS, processed, "stages")
        async with aiosqlite.connect(self.db_path) as db:
            cursor = await db.execute(
                f"UPDATE stages SET {set_clause} WHERE experiment_id = ? AND stage_type = ?",
                [*processed.values(), exp_id, stage_type],
            )
            if cursor.rowcount > 0:
                # Bump the parent experiment's updated_at in the same
                # transaction. Calling update_experiment(exp_id, {}) would not
                # do this: it returns early when given no updates.
                await db.execute(
                    "UPDATE experiments SET updated_at = ? WHERE id = ?",
                    (datetime.utcnow().isoformat(), exp_id),
                )
            await db.commit()

        return await self.get_stage(exp_id, stage_type)

    async def get_stage(
        self,
        exp_id: str,
        stage_type: str
    ) -> Optional[Dict[str, Any]]:
        """Return one stage row (JSON columns decoded), or ``None``."""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                "SELECT * FROM stages WHERE experiment_id = ? AND stage_type = ?",
                (exp_id, stage_type)
            ) as cursor:
                row = await cursor.fetchone()
        if not row:
            return None
        return self._decode_stage_row(row)

    async def get_all_stages(self, exp_id: str) -> List[Dict[str, Any]]:
        """
        Return all stages of an experiment in ``STAGE_TYPES`` order.

        Stage types not listed in ``STAGE_TYPES`` come last, ordered by id.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                "SELECT * FROM stages WHERE experiment_id = ? ORDER BY id",
                (exp_id,)
            ) as cursor:
                rows = await cursor.fetchall()

        stages = [self._decode_stage_row(row) for row in rows]
        # ORDER BY id alone sorts alphabetically by stage type (ids are
        # "<exp_id>-<stage_type>"), e.g. "asr" before "audio_slice". Re-sort
        # into STAGE_TYPES order; the sort is stable, so id order remains the
        # tie-breaker for unknown types.
        position = {name: index for index, name in enumerate(STAGE_TYPES)}
        stages.sort(key=lambda stage: position.get(stage["stage_type"], len(position)))
        return stages

    # ============================================================
    # File records
    # ============================================================

    async def create_file_record(self, file_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Insert a file record and return it as stored.

        Args:
            file_data: Must contain ``filename``. ``id`` defaults to a random
                UUID4 string and ``uploaded_at`` to the current UTC time.
        """
        file_id = file_data.get("id") or str(uuid.uuid4())
        now = datetime.utcnow().isoformat()

        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                '''INSERT INTO files
                   (id, filename, content_type, size_bytes, purpose,
                    duration_seconds, sample_rate, storage_path, uploaded_at)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)''',
                (
                    file_id,
                    file_data["filename"],
                    file_data.get("content_type"),
                    file_data.get("size_bytes", 0),
                    file_data.get("purpose", "training"),
                    file_data.get("duration_seconds"),
                    file_data.get("sample_rate"),
                    file_data.get("storage_path"),
                    file_data.get("uploaded_at", now),
                )
            )
            await db.commit()

        return await self.get_file_record(file_id)

    async def get_file_record(self, file_id: str) -> Optional[Dict[str, Any]]:
        """Return a file record as a dict, or ``None`` if it does not exist."""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                "SELECT * FROM files WHERE id = ?", (file_id,)
            ) as cursor:
                row = await cursor.fetchone()
        if row:
            return dict(row)
        return None

    async def delete_file_record(self, file_id: str) -> bool:
        """Delete a file record; return ``True`` if a row was removed."""
        async with aiosqlite.connect(self.db_path) as db:
            cursor = await db.execute(
                "DELETE FROM files WHERE id = ?", (file_id,)
            )
            await db.commit()
            return cursor.rowcount > 0

    async def list_file_records(
        self,
        purpose: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> List[Dict[str, Any]]:
        """List file records, most recently uploaded first, optionally by ``purpose``."""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            if purpose:
                query = """
                    SELECT * FROM files
                    WHERE purpose = ?
                    ORDER BY uploaded_at DESC
                    LIMIT ? OFFSET ?
                """
                params = (purpose, limit, offset)
            else:
                query = """
                    SELECT * FROM files
                    ORDER BY uploaded_at DESC
                    LIMIT ? OFFSET ?
                """
                params = (limit, offset)
            async with db.execute(query, params) as cursor:
                rows = await cursor.fetchall()
        return [dict(row) for row in rows]

    async def count_file_records(self, purpose: Optional[str] = None) -> int:
        """Count file records, optionally only those with the given ``purpose``."""
        async with aiosqlite.connect(self.db_path) as db:
            if purpose:
                async with db.execute(
                    "SELECT COUNT(*) FROM files WHERE purpose = ?", (purpose,)
                ) as cursor:
                    row = await cursor.fetchone()
            else:
                async with db.execute("SELECT COUNT(*) FROM files") as cursor:
                    row = await cursor.fetchone()
        return row[0] if row else 0