|
|
""" |
|
|
SQLite 数据库适配器 |
|
|
|
|
|
基于 SQLite + aiosqlite 实现的数据库适配器,适用于 macOS 本地训练场景。 |
|
|
""" |
|
|
|
|
|
import json |
|
|
import sqlite3 |
|
|
import uuid |
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict, List, Optional |
|
|
|
|
|
import aiosqlite |
|
|
|
|
|
from project_config import settings |
|
|
from ..base import DatabaseAdapter |
|
|
from ...models.domain import Task, TaskStatus |
|
|
|
|
|
|
|
|
# Stage identifiers for an experiment, presumably listed in pipeline order.
# SQLiteAdapter.create_experiment() seeds one ``stages`` row for each entry.
STAGE_TYPES = [
    "audio_slice", "asr", "text_feature", "hubert_feature",
    "semantic_token", "sovits_train", "gpt_train",
]
|
|
|
|
|
|
|
|
class SQLiteAdapter(DatabaseAdapter):
    """SQLite database adapter.

    Features:
        1. Asynchronous database access via aiosqlite.
        2. Manages Tasks (Quick Mode) and Experiments (Advanced Mode).
        3. Creates the table schema automatically on construction.

    Tables:
        - tasks: Quick Mode tasks
        - experiments: Advanced Mode experiments
        - stages: per-experiment stage state
        - files: file records

    Example::

        adapter = SQLiteAdapter()
        task = Task(id="task-123", exp_name="my_voice", config={})
        await adapter.create_task(task)
        task = await adapter.get_task("task-123")
    """

    # Columns accepted as keys by the dynamic ``UPDATE ... SET`` clauses in
    # update_task / update_experiment / update_stage. Column names cannot be
    # bound as ``?`` parameters, so they are interpolated into the SQL text
    # and must be whitelisted rather than trusted.
    _TASK_COLUMNS = frozenset({
        "id", "job_id", "exp_name", "status", "config", "current_stage",
        "progress", "stage_progress", "message", "error_message",
        "created_at", "started_at", "completed_at",
    })
    _EXPERIMENT_COLUMNS = frozenset({
        "id", "exp_name", "version", "exp_root", "gpu_numbers", "is_half",
        "audio_file_id", "status", "created_at", "updated_at",
    })
    _STAGE_COLUMNS = frozenset({
        "id", "experiment_id", "stage_type", "status", "progress", "message",
        "job_id", "config", "outputs", "started_at", "completed_at",
        "error_message",
    })

    def __init__(self, db_path: Optional[str] = None):
        """Initialize the adapter and make sure the schema exists.

        Args:
            db_path: Database file path. When falsy, ``settings.SQLITE_PATH``
                is used.
        """
        self.db_path = db_path if db_path else str(settings.SQLITE_PATH)

        # sqlite creates the file but not missing parent directories.
        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)

        # __init__ cannot await, so the schema is created synchronously.
        self._init_db_sync()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _utcnow_iso() -> str:
        """Return the current UTC time as a naive ISO-8601 string.

        Produces exactly the format of ``datetime.utcnow().isoformat()``
        (deprecated since Python 3.12), so new rows sort consistently with
        existing ones.
        """
        from datetime import timezone  # module header only imports ``datetime``

        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    @staticmethod
    def _to_db_value(value: Any) -> Any:
        """Convert a ``datetime`` to an ISO-8601 string; pass other values through.

        sqlite3's implicit datetime adapter (deprecated since Python 3.12)
        would store ``'YYYY-MM-DD HH:MM:SS'`` with a space separator, which
        sorts inconsistently with the ``'T'``-separated strings written
        elsewhere in this adapter.
        """
        return value.isoformat() if isinstance(value, datetime) else value

    @staticmethod
    def _build_set_clause(updates: Dict[str, Any], allowed: frozenset) -> str:
        """Build ``"col = ?, ..."`` for an UPDATE, rejecting unknown columns.

        Args:
            updates: Column -> value mapping (only keys are used here).
            allowed: Column names permitted for the target table.

        Returns:
            The SET clause text, with one ``?`` placeholder per key in
            ``updates`` iteration order.

        Raises:
            sqlite3.OperationalError: for a key not in ``allowed``. This is
                the same exception type and message SQLite raises for an
                unknown column.
        """
        for key in updates:
            if key not in allowed:
                raise sqlite3.OperationalError(f"no such column: {key}")
        return ", ".join(f"{key} = ?" for key in updates)

    @staticmethod
    def _decode_stage(row: Any) -> Dict[str, Any]:
        """Convert a ``stages`` row into a dict with JSON fields decoded.

        Non-empty string ``config``/``outputs`` values are JSON-decoded;
        undecodable JSON becomes None.
        """
        stage = dict(row)
        for field in ("config", "outputs"):
            raw = stage.get(field)
            if raw and isinstance(raw, str):
                try:
                    stage[field] = json.loads(raw)
                except json.JSONDecodeError:
                    stage[field] = None
        return stage

    def _init_db_sync(self) -> None:
        """Create all tables and indexes if they do not already exist.

        Runs synchronously with the stdlib ``sqlite3`` module. Using a
        ``sqlite3.Connection`` as a context manager only commits or rolls
        back; it does not close the connection. The connection is therefore
        closed explicitly in ``finally``.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:  # commit on success, roll back on error
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS tasks (
                        id TEXT PRIMARY KEY,
                        job_id TEXT,
                        exp_name TEXT NOT NULL,
                        status TEXT NOT NULL DEFAULT 'queued',
                        config TEXT,
                        current_stage TEXT,
                        progress REAL DEFAULT 0,
                        stage_progress REAL DEFAULT 0,
                        message TEXT,
                        error_message TEXT,
                        created_at TEXT NOT NULL,
                        started_at TEXT,
                        completed_at TEXT
                    )
                ''')
                conn.execute('CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status)')
                conn.execute('CREATE INDEX IF NOT EXISTS idx_tasks_created ON tasks(created_at)')

                conn.execute('''
                    CREATE TABLE IF NOT EXISTS experiments (
                        id TEXT PRIMARY KEY,
                        exp_name TEXT NOT NULL,
                        version TEXT NOT NULL DEFAULT 'v2',
                        exp_root TEXT DEFAULT 'logs',
                        gpu_numbers TEXT DEFAULT '0',
                        is_half INTEGER DEFAULT 1,
                        audio_file_id TEXT,
                        status TEXT NOT NULL DEFAULT 'created',
                        created_at TEXT NOT NULL,
                        updated_at TEXT
                    )
                ''')
                conn.execute('CREATE INDEX IF NOT EXISTS idx_experiments_status ON experiments(status)')
                conn.execute('CREATE INDEX IF NOT EXISTS idx_experiments_created ON experiments(created_at)')

                conn.execute('''
                    CREATE TABLE IF NOT EXISTS stages (
                        id TEXT PRIMARY KEY,
                        experiment_id TEXT NOT NULL,
                        stage_type TEXT NOT NULL,
                        status TEXT DEFAULT 'pending',
                        progress REAL DEFAULT 0,
                        message TEXT,
                        job_id TEXT,
                        config TEXT,
                        outputs TEXT,
                        started_at TEXT,
                        completed_at TEXT,
                        error_message TEXT,
                        FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE,
                        UNIQUE (experiment_id, stage_type)
                    )
                ''')
                conn.execute('CREATE INDEX IF NOT EXISTS idx_stages_experiment ON stages(experiment_id)')
                conn.execute('CREATE INDEX IF NOT EXISTS idx_stages_status ON stages(status)')

                conn.execute('''
                    CREATE TABLE IF NOT EXISTS files (
                        id TEXT PRIMARY KEY,
                        filename TEXT NOT NULL,
                        content_type TEXT,
                        size_bytes INTEGER DEFAULT 0,
                        purpose TEXT DEFAULT 'training',
                        duration_seconds REAL,
                        sample_rate INTEGER,
                        storage_path TEXT,
                        uploaded_at TEXT NOT NULL
                    )
                ''')
                conn.execute('CREATE INDEX IF NOT EXISTS idx_files_purpose ON files(purpose)')
                conn.execute('CREATE INDEX IF NOT EXISTS idx_files_uploaded ON files(uploaded_at)')
        finally:
            conn.close()

    # ------------------------------------------------------------------
    # Tasks (Quick Mode)
    # ------------------------------------------------------------------

    async def create_task(self, task: Task) -> Task:
        """Insert a new task row and return ``task`` as passed in.

        ``config`` is stored as JSON (NULL when empty). Datetimes are stored
        as ISO-8601 strings, and ``created_at`` defaults to the current UTC
        time when unset.
        """
        status = task.status.value if isinstance(task.status, TaskStatus) else task.status
        config = json.dumps(task.config, ensure_ascii=False) if task.config else None
        created_at = task.created_at.isoformat() if task.created_at else self._utcnow_iso()

        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                '''INSERT INTO tasks
                   (id, job_id, exp_name, status, config, current_stage,
                    progress, stage_progress, message, error_message,
                    created_at, started_at, completed_at)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)''',
                (
                    task.id,
                    task.job_id,
                    task.exp_name,
                    status,
                    config,
                    task.current_stage,
                    task.progress,
                    task.stage_progress,
                    task.message,
                    task.error_message,
                    created_at,
                    task.started_at.isoformat() if task.started_at else None,
                    task.completed_at.isoformat() if task.completed_at else None,
                ),
            )
            await db.commit()

        return task

    async def get_task(self, task_id: str) -> Optional[Task]:
        """Return the task with ``task_id``, or None if it does not exist."""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                "SELECT * FROM tasks WHERE id = ?", (task_id,)
            ) as cursor:
                row = await cursor.fetchone()
        return self._row_to_task(dict(row)) if row else None

    async def update_task(self, task_id: str, updates: Dict[str, Any]) -> Optional[Task]:
        """Apply a partial update to a task and return the refreshed task.

        Args:
            task_id: ID of the task to update.
            updates: Column -> value mapping. A ``TaskStatus`` is stored by its
                ``.value``, a dict ``config`` is JSON-encoded, and any
                ``datetime`` is stored as an ISO-8601 string.

        Returns:
            The task after the update, or None if it does not exist.

        Raises:
            sqlite3.OperationalError: if a key is not a ``tasks`` column.
        """
        if not updates:
            return await self.get_task(task_id)

        processed: Dict[str, Any] = {}
        for key, value in updates.items():
            if key == "status" and isinstance(value, TaskStatus):
                processed[key] = value.value
            elif key == "config" and isinstance(value, dict):
                processed[key] = json.dumps(value, ensure_ascii=False)
            else:
                processed[key] = self._to_db_value(value)

        set_clause = self._build_set_clause(processed, self._TASK_COLUMNS)
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                f"UPDATE tasks SET {set_clause} WHERE id = ?",
                [*processed.values(), task_id],
            )
            await db.commit()

        return await self.get_task(task_id)

    async def list_tasks(
        self,
        status: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> List[Task]:
        """List tasks newest-first, optionally filtered by ``status``.

        A falsy ``status`` disables the filter.
        """
        if status:
            query = (
                "SELECT * FROM tasks WHERE status = ? "
                "ORDER BY created_at DESC LIMIT ? OFFSET ?"
            )
            params: tuple = (status, limit, offset)
        else:
            query = "SELECT * FROM tasks ORDER BY created_at DESC LIMIT ? OFFSET ?"
            params = (limit, offset)

        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(query, params) as cursor:
                rows = await cursor.fetchall()
        return [self._row_to_task(dict(row)) for row in rows]

    async def delete_task(self, task_id: str) -> bool:
        """Delete a task; return True if a row was removed."""
        async with aiosqlite.connect(self.db_path) as db:
            cursor = await db.execute("DELETE FROM tasks WHERE id = ?", (task_id,))
            await db.commit()
            return cursor.rowcount > 0

    async def count_tasks(self, status: Optional[str] = None) -> int:
        """Count tasks, optionally only those with ``status`` (falsy = all)."""
        query = "SELECT COUNT(*) FROM tasks"
        params: tuple = ()
        if status:
            query += " WHERE status = ?"
            params = (status,)

        async with aiosqlite.connect(self.db_path) as db:
            async with db.execute(query, params) as cursor:
                row = await cursor.fetchone()
        return row[0] if row else 0

    async def get_task_by_exp_name(self, exp_name: str) -> Optional[Task]:
        """Return a task whose ``exp_name`` matches, or None.

        NOTE(review): the query has no ORDER BY, so when several tasks share
        ``exp_name`` which one is returned is not guaranteed; add an ORDER BY
        if callers need the newest or oldest.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                "SELECT * FROM tasks WHERE exp_name = ? LIMIT 1", (exp_name,)
            ) as cursor:
                row = await cursor.fetchone()
        return self._row_to_task(dict(row)) if row else None

    def _row_to_task(self, row: Dict[str, Any]) -> Task:
        """Convert a ``tasks`` row (as a dict) into a ``Task``.

        ``config`` is JSON-decoded; invalid JSON becomes ``{}``. NULL values
        in ``status``, ``progress`` and ``stage_progress`` fall back to their
        defaults. ``dict.get(key, default)`` alone would not do this, because
        a row dict always contains every column key.
        """
        config = row.get("config")
        if config and isinstance(config, str):
            try:
                config = json.loads(config)
            except json.JSONDecodeError:
                config = {}

        def value_or(key: str, default: Any) -> Any:
            value = row.get(key)
            return default if value is None else value

        return Task.from_dict({
            "id": row["id"],
            "job_id": row.get("job_id"),
            "exp_name": row["exp_name"],
            "status": value_or("status", "queued"),
            "config": config or {},
            "current_stage": row.get("current_stage"),
            "progress": value_or("progress", 0.0),
            "stage_progress": value_or("stage_progress", 0.0),
            "message": row.get("message"),
            "error_message": row.get("error_message"),
            "created_at": row.get("created_at"),
            "started_at": row.get("started_at"),
            "completed_at": row.get("completed_at"),
        })

    # ------------------------------------------------------------------
    # Experiments (Advanced Mode)
    # ------------------------------------------------------------------

    async def create_experiment(self, experiment: Dict[str, Any]) -> Dict[str, Any]:
        """Create an experiment plus one 'pending' stage row per STAGE_TYPES entry.

        Args:
            experiment: Must contain ``exp_name``. Optional keys are ``id``
                (a generated ``exp-xxxxxxxx`` id is used when missing),
                ``version``, ``exp_root``, ``gpu_numbers``, ``is_half``,
                ``audio_file_id`` and ``status``.

        Returns:
            The stored experiment, as returned by ``get_experiment``.
        """
        exp_id = experiment.get("id") or f"exp-{uuid.uuid4().hex[:8]}"
        now = self._utcnow_iso()

        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                '''INSERT INTO experiments
                   (id, exp_name, version, exp_root, gpu_numbers, is_half,
                    audio_file_id, status, created_at, updated_at)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)''',
                (
                    exp_id,
                    experiment["exp_name"],
                    experiment.get("version", "v2"),
                    experiment.get("exp_root", "logs"),
                    experiment.get("gpu_numbers", "0"),
                    1 if experiment.get("is_half", True) else 0,
                    experiment.get("audio_file_id"),
                    experiment.get("status", "created"),
                    now,
                    now,
                ),
            )

            # Seed stage rows in the same transaction as the experiment row.
            for stage_type in STAGE_TYPES:
                await db.execute(
                    '''INSERT INTO stages
                       (id, experiment_id, stage_type, status)
                       VALUES (?, ?, ?, 'pending')''',
                    (f"{exp_id}-{stage_type}", exp_id, stage_type),
                )

            await db.commit()

        return await self.get_experiment(exp_id)

    async def get_experiment(self, exp_id: str) -> Optional[Dict[str, Any]]:
        """Return an experiment dict with its stages, or None if missing.

        ``is_half`` is converted to bool. ``stages`` maps stage_type to a
        stage dict whose ``config``/``outputs`` fields are JSON-decoded.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row

            async with db.execute(
                "SELECT * FROM experiments WHERE id = ?", (exp_id,)
            ) as cursor:
                row = await cursor.fetchone()
            if not row:
                return None

            experiment = dict(row)
            experiment["is_half"] = bool(experiment.get("is_half", 1))

            async with db.execute(
                "SELECT * FROM stages WHERE experiment_id = ?", (exp_id,)
            ) as cursor:
                stage_rows = await cursor.fetchall()

        stages = {}
        for stage_row in stage_rows:
            stage = self._decode_stage(stage_row)
            stages[stage["stage_type"]] = stage
        experiment["stages"] = stages
        return experiment

    async def update_experiment(
        self,
        exp_id: str,
        updates: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """Apply a partial update to an experiment and return the refreshed one.

        Args:
            exp_id: ID of the experiment to update.
            updates: Column -> value mapping. ``is_half`` is stored as 0/1,
                any ``datetime`` as an ISO-8601 string, and a ``stages`` key
                is ignored (use ``update_stage``). ``updated_at`` is set to
                the current UTC time unless supplied.

        Returns:
            The experiment after the update, or None if it does not exist.
            An empty ``updates`` performs no write.

        Raises:
            sqlite3.OperationalError: if a key is not an ``experiments`` column.
        """
        if not updates:
            return await self.get_experiment(exp_id)

        processed: Dict[str, Any] = {}
        for key, value in updates.items():
            if key == "stages":
                continue  # stage state lives in its own table
            if key == "is_half":
                processed[key] = 1 if value else 0
            else:
                processed[key] = self._to_db_value(value)

        processed.setdefault("updated_at", self._utcnow_iso())

        set_clause = self._build_set_clause(processed, self._EXPERIMENT_COLUMNS)
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                f"UPDATE experiments SET {set_clause} WHERE id = ?",
                [*processed.values(), exp_id],
            )
            await db.commit()

        return await self.get_experiment(exp_id)

    async def list_experiments(
        self,
        status: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> List[Dict[str, Any]]:
        """List experiments newest-first, optionally filtered by ``status``.

        Stages are not included; ``is_half`` is converted to bool.
        """
        if status:
            query = (
                "SELECT * FROM experiments WHERE status = ? "
                "ORDER BY created_at DESC LIMIT ? OFFSET ?"
            )
            params: tuple = (status, limit, offset)
        else:
            query = "SELECT * FROM experiments ORDER BY created_at DESC LIMIT ? OFFSET ?"
            params = (limit, offset)

        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(query, params) as cursor:
                rows = await cursor.fetchall()

        results = []
        for row in rows:
            exp = dict(row)
            exp["is_half"] = bool(exp.get("is_half", 1))
            results.append(exp)
        return results

    async def delete_experiment(self, exp_id: str) -> bool:
        """Delete an experiment and its stages; return True if the experiment existed.

        Stages are deleted explicitly because SQLite only enforces the
        ``ON DELETE CASCADE`` foreign key when ``PRAGMA foreign_keys = ON``
        is set on the connection, and this adapter never sets it.
        """
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("DELETE FROM stages WHERE experiment_id = ?", (exp_id,))
            cursor = await db.execute("DELETE FROM experiments WHERE id = ?", (exp_id,))
            await db.commit()
            return cursor.rowcount > 0

    # ------------------------------------------------------------------
    # Stages
    # ------------------------------------------------------------------

    async def update_stage(
        self,
        exp_id: str,
        stage_type: str,
        updates: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """Apply a partial update to one stage and return the refreshed stage.

        A dict or list ``config``/``outputs`` is JSON-encoded, and any
        ``datetime`` is stored as an ISO-8601 string. When a stage row is
        updated, the parent experiment's ``updated_at`` is refreshed in the
        same transaction.

        Raises:
            sqlite3.OperationalError: if a key is not a ``stages`` column.
        """
        if not updates:
            return await self.get_stage(exp_id, stage_type)

        processed: Dict[str, Any] = {}
        for key, value in updates.items():
            if key in ("config", "outputs") and isinstance(value, (dict, list)):
                processed[key] = json.dumps(value, ensure_ascii=False)
            else:
                processed[key] = self._to_db_value(value)

        set_clause = self._build_set_clause(processed, self._STAGE_COLUMNS)
        async with aiosqlite.connect(self.db_path) as db:
            cursor = await db.execute(
                f"UPDATE stages SET {set_clause} WHERE experiment_id = ? AND stage_type = ?",
                [*processed.values(), exp_id, stage_type],
            )
            if cursor.rowcount > 0:
                # Touch the parent directly. update_experiment(exp_id, {})
                # returns early on empty updates and would never write.
                await db.execute(
                    "UPDATE experiments SET updated_at = ? WHERE id = ?",
                    (self._utcnow_iso(), exp_id),
                )
            await db.commit()

        return await self.get_stage(exp_id, stage_type)

    async def get_stage(
        self,
        exp_id: str,
        stage_type: str
    ) -> Optional[Dict[str, Any]]:
        """Return one stage dict (JSON fields decoded), or None if missing."""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                "SELECT * FROM stages WHERE experiment_id = ? AND stage_type = ?",
                (exp_id, stage_type),
            ) as cursor:
                row = await cursor.fetchone()
        return self._decode_stage(row) if row else None

    async def get_all_stages(self, exp_id: str) -> List[Dict[str, Any]]:
        """Return all stage dicts of an experiment (JSON fields decoded).

        Rows are ordered by ``id``. create_experiment builds ids as
        ``f"{exp_id}-{stage_type}"``, so this is alphabetical by stage_type,
        not STAGE_TYPES order.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                "SELECT * FROM stages WHERE experiment_id = ? ORDER BY id",
                (exp_id,),
            ) as cursor:
                rows = await cursor.fetchall()
        return [self._decode_stage(row) for row in rows]

    # ------------------------------------------------------------------
    # Files
    # ------------------------------------------------------------------

    async def create_file_record(self, file_data: Dict[str, Any]) -> Dict[str, Any]:
        """Insert a file record and return it as stored.

        Args:
            file_data: Must contain ``filename``. ``id`` defaults to a random
                UUID4 string. ``uploaded_at`` defaults to the current UTC time
                when missing or None; a ``datetime`` is stored as ISO-8601.
        """
        file_id = file_data.get("id") or str(uuid.uuid4())
        uploaded_at = file_data.get("uploaded_at")
        # An explicit None would violate ``uploaded_at TEXT NOT NULL``.
        uploaded_at = self._utcnow_iso() if uploaded_at is None else self._to_db_value(uploaded_at)

        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                '''INSERT INTO files
                   (id, filename, content_type, size_bytes, purpose,
                    duration_seconds, sample_rate, storage_path, uploaded_at)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)''',
                (
                    file_id,
                    file_data["filename"],
                    file_data.get("content_type"),
                    file_data.get("size_bytes", 0),
                    file_data.get("purpose", "training"),
                    file_data.get("duration_seconds"),
                    file_data.get("sample_rate"),
                    file_data.get("storage_path"),
                    uploaded_at,
                ),
            )
            await db.commit()

        return await self.get_file_record(file_id)

    async def get_file_record(self, file_id: str) -> Optional[Dict[str, Any]]:
        """Return a file record as a dict, or None if it does not exist."""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                "SELECT * FROM files WHERE id = ?", (file_id,)
            ) as cursor:
                row = await cursor.fetchone()
        return dict(row) if row else None

    async def delete_file_record(self, file_id: str) -> bool:
        """Delete a file record; return True if a row was removed."""
        async with aiosqlite.connect(self.db_path) as db:
            cursor = await db.execute("DELETE FROM files WHERE id = ?", (file_id,))
            await db.commit()
            return cursor.rowcount > 0

    async def list_file_records(
        self,
        purpose: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> List[Dict[str, Any]]:
        """List file records newest-first, optionally filtered by ``purpose``."""
        if purpose:
            query = (
                "SELECT * FROM files WHERE purpose = ? "
                "ORDER BY uploaded_at DESC LIMIT ? OFFSET ?"
            )
            params: tuple = (purpose, limit, offset)
        else:
            query = "SELECT * FROM files ORDER BY uploaded_at DESC LIMIT ? OFFSET ?"
            params = (limit, offset)

        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(query, params) as cursor:
                rows = await cursor.fetchall()
        return [dict(row) for row in rows]

    async def count_file_records(self, purpose: Optional[str] = None) -> int:
        """Count file records, optionally only those with ``purpose`` (falsy = all)."""
        query = "SELECT COUNT(*) FROM files"
        params: tuple = ()
        if purpose:
            query += " WHERE purpose = ?"
            params = (purpose,)

        async with aiosqlite.connect(self.db_path) as db:
            async with db.execute(query, params) as cursor:
                row = await cursor.fetchone()
        return row[0] if row else 0
|
|
|