""" SQLite 数据库适配器 基于 SQLite + aiosqlite 实现的数据库适配器,适用于 macOS 本地训练场景。 """ import json import sqlite3 import uuid from datetime import datetime from pathlib import Path from typing import Any, Dict, List, Optional import aiosqlite from project_config import settings from ..base import DatabaseAdapter from ...models.domain import Task, TaskStatus # 阶段类型列表 STAGE_TYPES = [ "audio_slice", "asr", "text_feature", "hubert_feature", "semantic_token", "sovits_train", "gpt_train", ] class SQLiteAdapter(DatabaseAdapter): """ SQLite 数据库适配器 特点: 1. 使用 aiosqlite 实现异步数据库操作 2. 支持 Task (Quick Mode) 和 Experiment (Advanced Mode) 管理 3. 自动初始化数据库表结构 表结构: - tasks: Quick Mode 任务 - experiments: Advanced Mode 实验 - stages: 实验阶段状态 - files: 文件记录 Example: >>> adapter = SQLiteAdapter() >>> task = Task(id="task-123", exp_name="my_voice", config={}) >>> await adapter.create_task(task) >>> task = await adapter.get_task("task-123") """ def __init__(self, db_path: Optional[str] = None): """ 初始化 SQLite 适配器 Args: db_path: 数据库文件路径,默认使用 settings.SQLITE_PATH """ if db_path: self.db_path = db_path else: self.db_path = str(settings.SQLITE_PATH) # 确保目录存在 Path(self.db_path).parent.mkdir(parents=True, exist_ok=True) # 同步初始化数据库 self._init_db_sync() def _init_db_sync(self) -> None: """同步初始化数据库表结构""" with sqlite3.connect(self.db_path) as conn: # Tasks 表 (Quick Mode) conn.execute(''' CREATE TABLE IF NOT EXISTS tasks ( id TEXT PRIMARY KEY, job_id TEXT, exp_name TEXT NOT NULL, status TEXT NOT NULL DEFAULT 'queued', config TEXT, current_stage TEXT, progress REAL DEFAULT 0, stage_progress REAL DEFAULT 0, message TEXT, error_message TEXT, created_at TEXT NOT NULL, started_at TEXT, completed_at TEXT ) ''') conn.execute('CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status)') conn.execute('CREATE INDEX IF NOT EXISTS idx_tasks_created ON tasks(created_at)') # Experiments 表 (Advanced Mode) conn.execute(''' CREATE TABLE IF NOT EXISTS experiments ( id TEXT PRIMARY KEY, exp_name TEXT NOT NULL, version TEXT NOT NULL DEFAULT 'v2', exp_root TEXT DEFAULT 'logs', gpu_numbers TEXT DEFAULT '0', is_half INTEGER DEFAULT 1, audio_file_id TEXT, status TEXT NOT NULL DEFAULT 'created', created_at TEXT NOT NULL, updated_at TEXT ) ''') conn.execute('CREATE INDEX IF NOT EXISTS idx_experiments_status ON experiments(status)') conn.execute('CREATE INDEX IF NOT EXISTS idx_experiments_created ON experiments(created_at)') # Stages 表 (Advanced Mode 阶段状态) conn.execute(''' CREATE TABLE IF NOT EXISTS stages ( id TEXT PRIMARY KEY, experiment_id TEXT NOT NULL, stage_type TEXT NOT NULL, status TEXT DEFAULT 'pending', progress REAL DEFAULT 0, message TEXT, job_id TEXT, config TEXT, outputs TEXT, started_at TEXT, completed_at TEXT, error_message TEXT, FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE, UNIQUE (experiment_id, stage_type) ) ''') conn.execute('CREATE INDEX IF NOT EXISTS idx_stages_experiment ON stages(experiment_id)') conn.execute('CREATE INDEX IF NOT EXISTS idx_stages_status ON stages(status)') # Files 表 (文件记录) conn.execute(''' CREATE TABLE IF NOT EXISTS files ( id TEXT PRIMARY KEY, filename TEXT NOT NULL, content_type TEXT, size_bytes INTEGER DEFAULT 0, purpose TEXT DEFAULT 'training', duration_seconds REAL, sample_rate INTEGER, storage_path TEXT, uploaded_at TEXT NOT NULL ) ''') conn.execute('CREATE INDEX IF NOT EXISTS idx_files_purpose ON files(purpose)') conn.execute('CREATE INDEX IF NOT EXISTS idx_files_uploaded ON files(uploaded_at)') conn.commit() # ============================================================ # Task CRUD (Quick Mode) # ============================================================ async def create_task(self, task: Task) -> Task: """创建任务""" async with aiosqlite.connect(self.db_path) as db: await db.execute( '''INSERT INTO tasks (id, job_id, exp_name, status, config, current_stage, progress, stage_progress, message, error_message, created_at, started_at, completed_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)''', ( task.id, task.job_id, task.exp_name, task.status.value if isinstance(task.status, TaskStatus) else task.status, json.dumps(task.config, ensure_ascii=False) if task.config else None, task.current_stage, task.progress, task.stage_progress, task.message, task.error_message, task.created_at.isoformat() if task.created_at else datetime.utcnow().isoformat(), task.started_at.isoformat() if task.started_at else None, task.completed_at.isoformat() if task.completed_at else None, ) ) await db.commit() return task async def get_task(self, task_id: str) -> Optional[Task]: """获取任务""" async with aiosqlite.connect(self.db_path) as db: db.row_factory = aiosqlite.Row async with db.execute( "SELECT * FROM tasks WHERE id = ?", (task_id,) ) as cursor: row = await cursor.fetchone() if row: return self._row_to_task(dict(row)) return None async def update_task(self, task_id: str, updates: Dict[str, Any]) -> Optional[Task]: """更新任务""" if not updates: return await self.get_task(task_id) # 处理特殊字段 processed = {} for key, value in updates.items(): if key == "status" and isinstance(value, TaskStatus): processed[key] = value.value elif key == "config" and isinstance(value, dict): processed[key] = json.dumps(value, ensure_ascii=False) elif key in ("created_at", "started_at", "completed_at") and isinstance(value, datetime): processed[key] = value.isoformat() else: processed[key] = value async with aiosqlite.connect(self.db_path) as db: set_clause = ", ".join(f"{k} = ?" for k in processed.keys()) values = list(processed.values()) + [task_id] await db.execute( f"UPDATE tasks SET {set_clause} WHERE id = ?", values ) await db.commit() return await self.get_task(task_id) async def list_tasks( self, status: Optional[str] = None, limit: int = 50, offset: int = 0 ) -> List[Task]: """查询任务列表""" async with aiosqlite.connect(self.db_path) as db: db.row_factory = aiosqlite.Row if status: query = """ SELECT * FROM tasks WHERE status = ? ORDER BY created_at DESC LIMIT ? OFFSET ? """ params = (status, limit, offset) else: query = """ SELECT * FROM tasks ORDER BY created_at DESC LIMIT ? OFFSET ? """ params = (limit, offset) async with db.execute(query, params) as cursor: rows = await cursor.fetchall() return [self._row_to_task(dict(row)) for row in rows] async def delete_task(self, task_id: str) -> bool: """删除任务""" async with aiosqlite.connect(self.db_path) as db: cursor = await db.execute( "DELETE FROM tasks WHERE id = ?", (task_id,) ) await db.commit() return cursor.rowcount > 0 async def count_tasks(self, status: Optional[str] = None) -> int: """统计任务数量""" async with aiosqlite.connect(self.db_path) as db: if status: async with db.execute( "SELECT COUNT(*) FROM tasks WHERE status = ?", (status,) ) as cursor: row = await cursor.fetchone() else: async with db.execute("SELECT COUNT(*) FROM tasks") as cursor: row = await cursor.fetchone() return row[0] if row else 0 async def get_task_by_exp_name(self, exp_name: str) -> Optional[Task]: """根据实验名称获取任务""" async with aiosqlite.connect(self.db_path) as db: db.row_factory = aiosqlite.Row async with db.execute( "SELECT * FROM tasks WHERE exp_name = ? LIMIT 1", (exp_name,) ) as cursor: row = await cursor.fetchone() if row: return self._row_to_task(dict(row)) return None def _row_to_task(self, row: Dict[str, Any]) -> Task: """将数据库行转换为 Task 对象""" # 解析 config JSON config = row.get("config") if config and isinstance(config, str): try: config = json.loads(config) except json.JSONDecodeError: config = {} return Task.from_dict({ "id": row["id"], "job_id": row.get("job_id"), "exp_name": row["exp_name"], "status": row.get("status", "queued"), "config": config or {}, "current_stage": row.get("current_stage"), "progress": row.get("progress", 0.0), "stage_progress": row.get("stage_progress", 0.0), "message": row.get("message"), "error_message": row.get("error_message"), "created_at": row.get("created_at"), "started_at": row.get("started_at"), "completed_at": row.get("completed_at"), }) # ============================================================ # Experiment CRUD (Advanced Mode) # ============================================================ async def create_experiment(self, experiment: Dict[str, Any]) -> Dict[str, Any]: """创建实验""" exp_id = experiment.get("id") or f"exp-{uuid.uuid4().hex[:8]}" now = datetime.utcnow().isoformat() async with aiosqlite.connect(self.db_path) as db: # 创建实验记录 await db.execute( '''INSERT INTO experiments (id, exp_name, version, exp_root, gpu_numbers, is_half, audio_file_id, status, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)''', ( exp_id, experiment["exp_name"], experiment.get("version", "v2"), experiment.get("exp_root", "logs"), experiment.get("gpu_numbers", "0"), 1 if experiment.get("is_half", True) else 0, experiment.get("audio_file_id"), experiment.get("status", "created"), now, now, ) ) # 创建所有阶段的初始状态 for stage_type in STAGE_TYPES: stage_id = f"{exp_id}-{stage_type}" await db.execute( '''INSERT INTO stages (id, experiment_id, stage_type, status) VALUES (?, ?, ?, 'pending')''', (stage_id, exp_id, stage_type) ) await db.commit() return await self.get_experiment(exp_id) async def get_experiment(self, exp_id: str) -> Optional[Dict[str, Any]]: """获取实验""" async with aiosqlite.connect(self.db_path) as db: db.row_factory = aiosqlite.Row # 获取实验基本信息 async with db.execute( "SELECT * FROM experiments WHERE id = ?", (exp_id,) ) as cursor: row = await cursor.fetchone() if not row: return None experiment = dict(row) experiment["is_half"] = bool(experiment.get("is_half", 1)) # 获取所有阶段状态 stages = {} async with db.execute( "SELECT * FROM stages WHERE experiment_id = ?", (exp_id,) ) as cursor: stage_rows = await cursor.fetchall() for stage_row in stage_rows: stage = dict(stage_row) stage_type = stage["stage_type"] # 解析 JSON 字段 for json_field in ("config", "outputs"): if stage.get(json_field) and isinstance(stage[json_field], str): try: stage[json_field] = json.loads(stage[json_field]) except json.JSONDecodeError: stage[json_field] = None stages[stage_type] = stage experiment["stages"] = stages return experiment async def update_experiment( self, exp_id: str, updates: Dict[str, Any] ) -> Optional[Dict[str, Any]]: """更新实验""" if not updates: return await self.get_experiment(exp_id) # 处理 is_half 布尔值 processed = {} for key, value in updates.items(): if key == "is_half": processed[key] = 1 if value else 0 elif key == "updated_at" and isinstance(value, datetime): processed[key] = value.isoformat() elif key != "stages": # stages 单独处理 processed[key] = value # 添加更新时间 if "updated_at" not in processed: processed["updated_at"] = datetime.utcnow().isoformat() async with aiosqlite.connect(self.db_path) as db: if processed: set_clause = ", ".join(f"{k} = ?" for k in processed.keys()) values = list(processed.values()) + [exp_id] await db.execute( f"UPDATE experiments SET {set_clause} WHERE id = ?", values ) await db.commit() return await self.get_experiment(exp_id) async def list_experiments( self, status: Optional[str] = None, limit: int = 50, offset: int = 0 ) -> List[Dict[str, Any]]: """查询实验列表""" async with aiosqlite.connect(self.db_path) as db: db.row_factory = aiosqlite.Row if status: query = """ SELECT * FROM experiments WHERE status = ? ORDER BY created_at DESC LIMIT ? OFFSET ? """ params = (status, limit, offset) else: query = """ SELECT * FROM experiments ORDER BY created_at DESC LIMIT ? OFFSET ? """ params = (limit, offset) async with db.execute(query, params) as cursor: rows = await cursor.fetchall() results = [] for row in rows: exp = dict(row) exp["is_half"] = bool(exp.get("is_half", 1)) # 简化列表,不包含完整的 stages results.append(exp) return results async def delete_experiment(self, exp_id: str) -> bool: """删除实验及其阶段""" async with aiosqlite.connect(self.db_path) as db: # 先删除阶段 await db.execute( "DELETE FROM stages WHERE experiment_id = ?", (exp_id,) ) # 再删除实验 cursor = await db.execute( "DELETE FROM experiments WHERE id = ?", (exp_id,) ) await db.commit() return cursor.rowcount > 0 # ============================================================ # Stage 操作 (Advanced Mode) # ============================================================ async def update_stage( self, exp_id: str, stage_type: str, updates: Dict[str, Any] ) -> Optional[Dict[str, Any]]: """更新阶段状态""" if not updates: return await self.get_stage(exp_id, stage_type) # 处理 JSON 字段 processed = {} for key, value in updates.items(): if key in ("config", "outputs") and isinstance(value, dict): processed[key] = json.dumps(value, ensure_ascii=False) elif key in ("started_at", "completed_at") and isinstance(value, datetime): processed[key] = value.isoformat() else: processed[key] = value async with aiosqlite.connect(self.db_path) as db: set_clause = ", ".join(f"{k} = ?" for k in processed.keys()) values = list(processed.values()) + [exp_id, stage_type] await db.execute( f"UPDATE stages SET {set_clause} WHERE experiment_id = ? AND stage_type = ?", values ) await db.commit() # 同时更新实验的 updated_at await self.update_experiment(exp_id, {}) return await self.get_stage(exp_id, stage_type) async def get_stage( self, exp_id: str, stage_type: str ) -> Optional[Dict[str, Any]]: """获取阶段状态""" async with aiosqlite.connect(self.db_path) as db: db.row_factory = aiosqlite.Row async with db.execute( "SELECT * FROM stages WHERE experiment_id = ? AND stage_type = ?", (exp_id, stage_type) ) as cursor: row = await cursor.fetchone() if not row: return None stage = dict(row) # 解析 JSON 字段 for json_field in ("config", "outputs"): if stage.get(json_field) and isinstance(stage[json_field], str): try: stage[json_field] = json.loads(stage[json_field]) except json.JSONDecodeError: stage[json_field] = None return stage async def get_all_stages(self, exp_id: str) -> List[Dict[str, Any]]: """获取实验的所有阶段状态""" async with aiosqlite.connect(self.db_path) as db: db.row_factory = aiosqlite.Row async with db.execute( "SELECT * FROM stages WHERE experiment_id = ? ORDER BY id", (exp_id,) ) as cursor: rows = await cursor.fetchall() results = [] for row in rows: stage = dict(row) # 解析 JSON 字段 for json_field in ("config", "outputs"): if stage.get(json_field) and isinstance(stage[json_field], str): try: stage[json_field] = json.loads(stage[json_field]) except json.JSONDecodeError: stage[json_field] = None results.append(stage) return results # ============================================================ # File 记录 # ============================================================ async def create_file_record(self, file_data: Dict[str, Any]) -> Dict[str, Any]: """创建文件记录""" file_id = file_data.get("id") or str(uuid.uuid4()) now = datetime.utcnow().isoformat() async with aiosqlite.connect(self.db_path) as db: await db.execute( '''INSERT INTO files (id, filename, content_type, size_bytes, purpose, duration_seconds, sample_rate, storage_path, uploaded_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)''', ( file_id, file_data["filename"], file_data.get("content_type"), file_data.get("size_bytes", 0), file_data.get("purpose", "training"), file_data.get("duration_seconds"), file_data.get("sample_rate"), file_data.get("storage_path"), file_data.get("uploaded_at", now), ) ) await db.commit() return await self.get_file_record(file_id) async def get_file_record(self, file_id: str) -> Optional[Dict[str, Any]]: """获取文件记录""" async with aiosqlite.connect(self.db_path) as db: db.row_factory = aiosqlite.Row async with db.execute( "SELECT * FROM files WHERE id = ?", (file_id,) ) as cursor: row = await cursor.fetchone() if row: return dict(row) return None async def delete_file_record(self, file_id: str) -> bool: """删除文件记录""" async with aiosqlite.connect(self.db_path) as db: cursor = await db.execute( "DELETE FROM files WHERE id = ?", (file_id,) ) await db.commit() return cursor.rowcount > 0 async def list_file_records( self, purpose: Optional[str] = None, limit: int = 50, offset: int = 0 ) -> List[Dict[str, Any]]: """查询文件记录列表""" async with aiosqlite.connect(self.db_path) as db: db.row_factory = aiosqlite.Row if purpose: query = """ SELECT * FROM files WHERE purpose = ? ORDER BY uploaded_at DESC LIMIT ? OFFSET ? """ params = (purpose, limit, offset) else: query = """ SELECT * FROM files ORDER BY uploaded_at DESC LIMIT ? OFFSET ? """ params = (limit, offset) async with db.execute(query, params) as cursor: rows = await cursor.fetchall() return [dict(row) for row in rows] async def count_file_records(self, purpose: Optional[str] = None) -> int: """统计文件记录数量""" async with aiosqlite.connect(self.db_path) as db: if purpose: async with db.execute( "SELECT COUNT(*) FROM files WHERE purpose = ?", (purpose,) ) as cursor: row = await cursor.fetchone() else: async with db.execute("SELECT COUNT(*) FROM files") as cursor: row = await cursor.fetchone() return row[0] if row else 0