"""
SQLite database adapter.

A database adapter built on SQLite + aiosqlite, intended for local training
scenarios on macOS.
"""
| |
|
| | import json |
| | import sqlite3 |
| | import uuid |
| | from datetime import datetime |
| | from pathlib import Path |
| | from typing import Any, Dict, List, Optional |
| |
|
| | import aiosqlite |
| |
|
| | from project_config import settings |
| | from ..base import DatabaseAdapter |
| | from ...models.domain import Task, TaskStatus |
| |
|
| | |
# Stage identifiers of an experiment. SQLiteAdapter.create_experiment inserts
# one row into the ``stages`` table per entry, iterating in this order.
STAGE_TYPES = [
    "audio_slice", "asr",
    "text_feature", "hubert_feature", "semantic_token",
    "sovits_train", "gpt_train",
]
| |
|
| |
|
class SQLiteAdapter(DatabaseAdapter):
    """
    SQLite database adapter.

    Features:
        1. Asynchronous database access through aiosqlite.
        2. Manages Task (Quick Mode) and Experiment (Advanced Mode) records.
        3. Creates the table structure automatically on construction.

    Tables:
        - tasks: Quick Mode tasks
        - experiments: Advanced Mode experiments
        - stages: per-experiment stage state
        - files: file records

    Timestamps generated by this adapter are naive UTC values serialised
    with ``datetime.isoformat()``.

    Example:
        >>> adapter = SQLiteAdapter()
        >>> task = Task(id="task-123", exp_name="my_voice", config={})
        >>> await adapter.create_task(task)
        >>> task = await adapter.get_task("task-123")
    """

    # Columns of the ``stages`` table that hold JSON-encoded text.
    _STAGE_JSON_FIELDS = ("config", "outputs")

    def __init__(self, db_path: Optional[str] = None):
        """
        Initialise the SQLite adapter.

        Creates the parent directory of the database file if necessary and
        then creates any missing tables/indexes synchronously.

        Args:
            db_path: Path of the database file. Falls back to
                ``settings.SQLITE_PATH`` when omitted or empty.
        """
        if db_path:
            self.db_path = db_path
        else:
            self.db_path = str(settings.SQLITE_PATH)

        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)

        self._init_db_sync()

    def _init_db_sync(self) -> None:
        """Create all tables and indexes synchronously (idempotent)."""
        schema = (
            # Quick Mode tasks.
            '''
            CREATE TABLE IF NOT EXISTS tasks (
                id TEXT PRIMARY KEY,
                job_id TEXT,
                exp_name TEXT NOT NULL,
                status TEXT NOT NULL DEFAULT 'queued',
                config TEXT,
                current_stage TEXT,
                progress REAL DEFAULT 0,
                stage_progress REAL DEFAULT 0,
                message TEXT,
                error_message TEXT,
                created_at TEXT NOT NULL,
                started_at TEXT,
                completed_at TEXT
            )
            ''',
            'CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status)',
            'CREATE INDEX IF NOT EXISTS idx_tasks_created ON tasks(created_at)',
            # Advanced Mode experiments.
            '''
            CREATE TABLE IF NOT EXISTS experiments (
                id TEXT PRIMARY KEY,
                exp_name TEXT NOT NULL,
                version TEXT NOT NULL DEFAULT 'v2',
                exp_root TEXT DEFAULT 'logs',
                gpu_numbers TEXT DEFAULT '0',
                is_half INTEGER DEFAULT 1,
                audio_file_id TEXT,
                status TEXT NOT NULL DEFAULT 'created',
                created_at TEXT NOT NULL,
                updated_at TEXT
            )
            ''',
            'CREATE INDEX IF NOT EXISTS idx_experiments_status ON experiments(status)',
            'CREATE INDEX IF NOT EXISTS idx_experiments_created ON experiments(created_at)',
            # Per-experiment stage state.
            '''
            CREATE TABLE IF NOT EXISTS stages (
                id TEXT PRIMARY KEY,
                experiment_id TEXT NOT NULL,
                stage_type TEXT NOT NULL,
                status TEXT DEFAULT 'pending',
                progress REAL DEFAULT 0,
                message TEXT,
                job_id TEXT,
                config TEXT,
                outputs TEXT,
                started_at TEXT,
                completed_at TEXT,
                error_message TEXT,
                FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE,
                UNIQUE (experiment_id, stage_type)
            )
            ''',
            'CREATE INDEX IF NOT EXISTS idx_stages_experiment ON stages(experiment_id)',
            'CREATE INDEX IF NOT EXISTS idx_stages_status ON stages(status)',
            # Uploaded file records.
            '''
            CREATE TABLE IF NOT EXISTS files (
                id TEXT PRIMARY KEY,
                filename TEXT NOT NULL,
                content_type TEXT,
                size_bytes INTEGER DEFAULT 0,
                purpose TEXT DEFAULT 'training',
                duration_seconds REAL,
                sample_rate INTEGER,
                storage_path TEXT,
                uploaded_at TEXT NOT NULL
            )
            ''',
            'CREATE INDEX IF NOT EXISTS idx_files_purpose ON files(purpose)',
            'CREATE INDEX IF NOT EXISTS idx_files_uploaded ON files(uploaded_at)',
        )

        # ``with sqlite3.connect(...)`` only commits/rolls back; it does not
        # close the connection, so close it explicitly.
        conn = sqlite3.connect(self.db_path)
        try:
            for statement in schema:
                conn.execute(statement)
            conn.commit()
        finally:
            conn.close()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _build_set_clause(columns: Dict[str, Any]) -> str:
        """
        Build the ``col = ?, col = ?`` part of an UPDATE statement.

        Column names cannot be bound as SQL parameters, so every key is
        checked to be a plain identifier before being placed in the SQL text.

        Raises:
            ValueError: If a key is not a plain identifier string.
        """
        for name in columns:
            if not isinstance(name, str) or not name.isidentifier():
                raise ValueError(f"Invalid column name in updates: {name!r}")
        return ", ".join(f"{name} = ?" for name in columns)

    @classmethod
    def _decode_stage_row(cls, row: Any) -> Dict[str, Any]:
        """Convert a ``stages`` row to a dict, decoding its JSON columns.

        A JSON column that fails to parse is replaced by ``None``.
        """
        stage = dict(row)
        for field in cls._STAGE_JSON_FIELDS:
            raw = stage.get(field)
            if raw and isinstance(raw, str):
                try:
                    stage[field] = json.loads(raw)
                except json.JSONDecodeError:
                    stage[field] = None
        return stage

    @staticmethod
    def _decode_experiment_row(row: Any) -> Dict[str, Any]:
        """Convert an ``experiments`` row to a dict with ``is_half`` as bool."""
        experiment = dict(row)
        experiment["is_half"] = bool(experiment.get("is_half", 1))
        return experiment

    async def _fetch_one(self, query: str, params: tuple) -> Optional[Dict[str, Any]]:
        """Run ``query`` and return the first row as a dict, or ``None``."""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(query, params) as cursor:
                row = await cursor.fetchone()
        return dict(row) if row else None

    async def _select_page(
        self,
        table: str,
        filter_column: str,
        filter_value: Optional[str],
        order_column: str,
        limit: int,
        offset: int,
    ) -> List[Dict[str, Any]]:
        """
        Return one page of rows from ``table``, newest first.

        ``table``, ``filter_column`` and ``order_column`` are internal
        constants supplied by this class, never caller input. The filter is
        applied only when ``filter_value`` is truthy.
        """
        query = f"SELECT * FROM {table}"
        params: List[Any] = []
        if filter_value:
            query += f" WHERE {filter_column} = ?"
            params.append(filter_value)
        query += f" ORDER BY {order_column} DESC LIMIT ? OFFSET ?"
        params.extend((limit, offset))

        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(query, params) as cursor:
                rows = await cursor.fetchall()
        return [dict(row) for row in rows]

    async def _count_rows(
        self,
        table: str,
        filter_column: str,
        filter_value: Optional[str],
    ) -> int:
        """Count rows of ``table``, optionally restricted by an equality filter."""
        query = f"SELECT COUNT(*) FROM {table}"
        params: tuple = ()
        if filter_value:
            query += f" WHERE {filter_column} = ?"
            params = (filter_value,)

        async with aiosqlite.connect(self.db_path) as db:
            async with db.execute(query, params) as cursor:
                row = await cursor.fetchone()
        return row[0] if row else 0

    async def _delete_by_id(self, table: str, row_id: str) -> bool:
        """Delete the row with primary key ``row_id``; True if a row was removed."""
        async with aiosqlite.connect(self.db_path) as db:
            cursor = await db.execute(f"DELETE FROM {table} WHERE id = ?", (row_id,))
            await db.commit()
            return cursor.rowcount > 0

    # ------------------------------------------------------------------
    # Tasks (Quick Mode)
    # ------------------------------------------------------------------

    async def create_task(self, task: Task) -> Task:
        """
        Insert a task row.

        Args:
            task: Task to persist. ``config`` is stored as JSON text (NULL when
                empty); a missing ``created_at`` is replaced by the current
                UTC time in the stored row.

        Returns:
            The same ``task`` object that was passed in.
        """
        status = task.status.value if isinstance(task.status, TaskStatus) else task.status
        config_json = json.dumps(task.config, ensure_ascii=False) if task.config else None
        created_at = (
            task.created_at.isoformat() if task.created_at else datetime.utcnow().isoformat()
        )
        started_at = task.started_at.isoformat() if task.started_at else None
        completed_at = task.completed_at.isoformat() if task.completed_at else None

        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                '''INSERT INTO tasks
                   (id, job_id, exp_name, status, config, current_stage,
                    progress, stage_progress, message, error_message,
                    created_at, started_at, completed_at)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)''',
                (
                    task.id,
                    task.job_id,
                    task.exp_name,
                    status,
                    config_json,
                    task.current_stage,
                    task.progress,
                    task.stage_progress,
                    task.message,
                    task.error_message,
                    created_at,
                    started_at,
                    completed_at,
                ),
            )
            await db.commit()

        return task

    async def get_task(self, task_id: str) -> Optional[Task]:
        """Return the task with the given id, or ``None`` if it does not exist."""
        row = await self._fetch_one("SELECT * FROM tasks WHERE id = ?", (task_id,))
        return self._row_to_task(row) if row else None

    async def update_task(self, task_id: str, updates: Dict[str, Any]) -> Optional[Task]:
        """
        Update columns of a task.

        Args:
            task_id: Id of the task to update.
            updates: Mapping of column name to new value. ``TaskStatus``
                values, ``dict`` configs and ``datetime`` timestamps are
                serialised before being written.

        Returns:
            The task as stored after the update, or ``None`` if no task has
            this id.

        Raises:
            ValueError: If a key of ``updates`` is not a plain identifier.
        """
        if not updates:
            return await self.get_task(task_id)

        processed: Dict[str, Any] = {}
        for key, value in updates.items():
            if key == "status" and isinstance(value, TaskStatus):
                processed[key] = value.value
            elif key == "config" and isinstance(value, dict):
                processed[key] = json.dumps(value, ensure_ascii=False)
            elif key in ("created_at", "started_at", "completed_at") and isinstance(value, datetime):
                processed[key] = value.isoformat()
            else:
                processed[key] = value

        set_clause = self._build_set_clause(processed)
        values = list(processed.values()) + [task_id]

        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(f"UPDATE tasks SET {set_clause} WHERE id = ?", values)
            await db.commit()

        return await self.get_task(task_id)

    async def list_tasks(
        self,
        status: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> List[Task]:
        """
        List tasks ordered by ``created_at`` descending.

        Args:
            status: Only return tasks with this status; no filter when falsy.
            limit: Maximum number of tasks to return.
            offset: Number of tasks to skip.
        """
        rows = await self._select_page("tasks", "status", status, "created_at", limit, offset)
        return [self._row_to_task(row) for row in rows]

    async def delete_task(self, task_id: str) -> bool:
        """Delete a task; return True if a row was removed."""
        return await self._delete_by_id("tasks", task_id)

    async def count_tasks(self, status: Optional[str] = None) -> int:
        """Count tasks, optionally restricted to one status."""
        return await self._count_rows("tasks", "status", status)

    async def get_task_by_exp_name(self, exp_name: str) -> Optional[Task]:
        """Return one task with the given experiment name, or ``None``.

        No ordering is applied, so which task is returned when several share
        the name is up to SQLite.
        """
        row = await self._fetch_one(
            "SELECT * FROM tasks WHERE exp_name = ? LIMIT 1", (exp_name,)
        )
        return self._row_to_task(row) if row else None

    def _row_to_task(self, row: Dict[str, Any]) -> Task:
        """Convert a ``tasks`` row (as a dict) into a ``Task`` object.

        The JSON ``config`` column is decoded; undecodable or empty config
        becomes ``{}``.
        """
        config = row.get("config")
        if config and isinstance(config, str):
            try:
                config = json.loads(config)
            except json.JSONDecodeError:
                config = {}

        return Task.from_dict({
            "id": row["id"],
            "job_id": row.get("job_id"),
            "exp_name": row["exp_name"],
            "status": row.get("status", "queued"),
            "config": config or {},
            "current_stage": row.get("current_stage"),
            "progress": row.get("progress", 0.0),
            "stage_progress": row.get("stage_progress", 0.0),
            "message": row.get("message"),
            "error_message": row.get("error_message"),
            "created_at": row.get("created_at"),
            "started_at": row.get("started_at"),
            "completed_at": row.get("completed_at"),
        })

    # ------------------------------------------------------------------
    # Experiments (Advanced Mode)
    # ------------------------------------------------------------------

    async def create_experiment(self, experiment: Dict[str, Any]) -> Dict[str, Any]:
        """
        Insert an experiment and one ``pending`` stage row per STAGE_TYPES entry.

        Args:
            experiment: Experiment fields. ``exp_name`` is required; ``id`` is
                generated as ``exp-<8 hex chars>`` when missing or empty.

        Returns:
            The stored experiment including its ``stages`` mapping.

        Raises:
            KeyError: If ``exp_name`` is missing.
        """
        exp_id = experiment.get("id") or f"exp-{uuid.uuid4().hex[:8]}"
        now = datetime.utcnow().isoformat()

        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                '''INSERT INTO experiments
                   (id, exp_name, version, exp_root, gpu_numbers, is_half,
                    audio_file_id, status, created_at, updated_at)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)''',
                (
                    exp_id,
                    experiment["exp_name"],
                    experiment.get("version", "v2"),
                    experiment.get("exp_root", "logs"),
                    experiment.get("gpu_numbers", "0"),
                    1 if experiment.get("is_half", True) else 0,
                    experiment.get("audio_file_id"),
                    experiment.get("status", "created"),
                    now,
                    now,
                ),
            )

            # Stage ids are derived from the experiment id and the stage type.
            for stage_type in STAGE_TYPES:
                await db.execute(
                    '''INSERT INTO stages
                       (id, experiment_id, stage_type, status)
                       VALUES (?, ?, ?, 'pending')''',
                    (f"{exp_id}-{stage_type}", exp_id, stage_type),
                )

            # Experiment and stages are committed together.
            await db.commit()

        return await self.get_experiment(exp_id)

    async def get_experiment(self, exp_id: str) -> Optional[Dict[str, Any]]:
        """
        Return an experiment with its stages.

        Returns:
            The experiment row as a dict with ``is_half`` converted to bool
            and a ``stages`` dict keyed by stage type, or ``None`` if the
            experiment does not exist.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row

            async with db.execute(
                "SELECT * FROM experiments WHERE id = ?", (exp_id,)
            ) as cursor:
                row = await cursor.fetchone()
            if not row:
                return None

            experiment = self._decode_experiment_row(row)

            async with db.execute(
                "SELECT * FROM stages WHERE experiment_id = ?", (exp_id,)
            ) as cursor:
                stage_rows = await cursor.fetchall()

        decoded = (self._decode_stage_row(stage_row) for stage_row in stage_rows)
        experiment["stages"] = {stage["stage_type"]: stage for stage in decoded}
        return experiment

    async def update_experiment(
        self,
        exp_id: str,
        updates: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        Update columns of an experiment.

        Args:
            exp_id: Id of the experiment to update.
            updates: Mapping of column name to new value. A ``stages`` key is
                ignored. ``updated_at`` is set to the current UTC time unless
                supplied explicitly. Empty ``updates`` performs no write.

        Returns:
            The experiment (with stages) after the update, or ``None`` if it
            does not exist.

        Raises:
            ValueError: If a key of ``updates`` is not a plain identifier.
        """
        if not updates:
            return await self.get_experiment(exp_id)

        processed: Dict[str, Any] = {}
        for key, value in updates.items():
            if key == "stages":
                # Stages live in their own table; see update_stage.
                continue
            if key == "is_half":
                processed[key] = 1 if value else 0
            elif key == "updated_at" and isinstance(value, datetime):
                processed[key] = value.isoformat()
            else:
                processed[key] = value

        if "updated_at" not in processed:
            processed["updated_at"] = datetime.utcnow().isoformat()

        set_clause = self._build_set_clause(processed)
        values = list(processed.values()) + [exp_id]

        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(f"UPDATE experiments SET {set_clause} WHERE id = ?", values)
            await db.commit()

        return await self.get_experiment(exp_id)

    async def list_experiments(
        self,
        status: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> List[Dict[str, Any]]:
        """
        List experiments ordered by ``created_at`` descending.

        The returned dicts do not include a ``stages`` entry.

        Args:
            status: Only return experiments with this status; no filter when falsy.
            limit: Maximum number of experiments to return.
            offset: Number of experiments to skip.
        """
        rows = await self._select_page(
            "experiments", "status", status, "created_at", limit, offset
        )
        return [self._decode_experiment_row(row) for row in rows]

    async def delete_experiment(self, exp_id: str) -> bool:
        """Delete an experiment and its stages; True if the experiment existed."""
        async with aiosqlite.connect(self.db_path) as db:
            # Stages are removed explicitly: SQLite only honours
            # ON DELETE CASCADE when foreign-key enforcement is switched on
            # for the connection, which this adapter does not do.
            await db.execute("DELETE FROM stages WHERE experiment_id = ?", (exp_id,))
            cursor = await db.execute("DELETE FROM experiments WHERE id = ?", (exp_id,))
            await db.commit()
            return cursor.rowcount > 0

    # ------------------------------------------------------------------
    # Stages
    # ------------------------------------------------------------------

    async def update_stage(
        self,
        exp_id: str,
        stage_type: str,
        updates: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        Update one stage of an experiment and touch the experiment's
        ``updated_at``.

        Args:
            exp_id: Id of the owning experiment.
            stage_type: Stage type identifying the stage within the experiment.
            updates: Mapping of column name to new value. ``dict``/``list``
                values for ``config``/``outputs`` are stored as JSON text and
                ``datetime`` timestamps as ISO strings. Empty ``updates``
                performs no write.

        Returns:
            The stage after the update, or ``None`` if it does not exist.

        Raises:
            ValueError: If a key of ``updates`` is not a plain identifier.
        """
        if not updates:
            return await self.get_stage(exp_id, stage_type)

        processed: Dict[str, Any] = {}
        for key, value in updates.items():
            if key in self._STAGE_JSON_FIELDS and isinstance(value, (dict, list)):
                processed[key] = json.dumps(value, ensure_ascii=False)
            elif key in ("started_at", "completed_at") and isinstance(value, datetime):
                processed[key] = value.isoformat()
            else:
                processed[key] = value

        set_clause = self._build_set_clause(processed)
        values = list(processed.values()) + [exp_id, stage_type]

        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                f"UPDATE stages SET {set_clause} WHERE experiment_id = ? AND stage_type = ?",
                values,
            )
            # Touch the parent experiment in the same transaction. Calling
            # update_experiment(exp_id, {}) would not do this, because empty
            # updates return before any write.
            await db.execute(
                "UPDATE experiments SET updated_at = ? WHERE id = ?",
                (datetime.utcnow().isoformat(), exp_id),
            )
            await db.commit()

        return await self.get_stage(exp_id, stage_type)

    async def get_stage(
        self,
        exp_id: str,
        stage_type: str
    ) -> Optional[Dict[str, Any]]:
        """Return one stage of an experiment with JSON columns decoded, or ``None``."""
        row = await self._fetch_one(
            "SELECT * FROM stages WHERE experiment_id = ? AND stage_type = ?",
            (exp_id, stage_type),
        )
        return self._decode_stage_row(row) if row else None

    async def get_all_stages(self, exp_id: str) -> List[Dict[str, Any]]:
        """
        Return all stages of an experiment with JSON columns decoded.

        Rows are ordered by stage ``id``. For stages inserted by
        ``create_experiment`` the id is ``<exp_id>-<stage_type>``, so this is
        alphabetical by stage type rather than STAGE_TYPES order.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                "SELECT * FROM stages WHERE experiment_id = ? ORDER BY id",
                (exp_id,),
            ) as cursor:
                rows = await cursor.fetchall()

        return [self._decode_stage_row(row) for row in rows]

    # ------------------------------------------------------------------
    # Files
    # ------------------------------------------------------------------

    async def create_file_record(self, file_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Insert a file record.

        Args:
            file_data: File fields. ``filename`` is required; ``id`` defaults
                to a new UUID4 string and ``uploaded_at`` to the current UTC
                time when missing or ``None``.

        Returns:
            The stored file record.

        Raises:
            KeyError: If ``filename`` is missing.
        """
        file_id = file_data.get("id") or str(uuid.uuid4())

        # ``uploaded_at`` is NOT NULL, so an explicit None also falls back to
        # "now"; datetimes are stored in the same ISO format as the default.
        uploaded_at = file_data.get("uploaded_at") or datetime.utcnow().isoformat()
        if isinstance(uploaded_at, datetime):
            uploaded_at = uploaded_at.isoformat()

        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                '''INSERT INTO files
                   (id, filename, content_type, size_bytes, purpose,
                    duration_seconds, sample_rate, storage_path, uploaded_at)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)''',
                (
                    file_id,
                    file_data["filename"],
                    file_data.get("content_type"),
                    file_data.get("size_bytes", 0),
                    file_data.get("purpose", "training"),
                    file_data.get("duration_seconds"),
                    file_data.get("sample_rate"),
                    file_data.get("storage_path"),
                    uploaded_at,
                ),
            )
            await db.commit()

        return await self.get_file_record(file_id)

    async def get_file_record(self, file_id: str) -> Optional[Dict[str, Any]]:
        """Return the file record with the given id, or ``None``."""
        return await self._fetch_one("SELECT * FROM files WHERE id = ?", (file_id,))

    async def delete_file_record(self, file_id: str) -> bool:
        """Delete a file record; return True if a row was removed."""
        return await self._delete_by_id("files", file_id)

    async def list_file_records(
        self,
        purpose: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> List[Dict[str, Any]]:
        """
        List file records ordered by ``uploaded_at`` descending.

        Args:
            purpose: Only return records with this purpose; no filter when falsy.
            limit: Maximum number of records to return.
            offset: Number of records to skip.
        """
        return await self._select_page(
            "files", "purpose", purpose, "uploaded_at", limit, offset
        )

    async def count_file_records(self, purpose: Optional[str] = None) -> int:
        """Count file records, optionally restricted to one purpose."""
        return await self._count_rows("files", "purpose", purpose)
| |
|