| """Schema introspection and mutation helpers for synthetic database generation.""" |
|
|
| from __future__ import annotations |
|
|
| import sqlite3 |
| from dataclasses import dataclass |
| from pathlib import Path |
|
|
|
|
@dataclass
class MutationResult:
    """Result of applying a single mutation to a database."""


    # Identifier of the mutation that produced this result, e.g.
    # "inject_irrelevant_rows", "remap_ids", or "duplicate_bridge_rows".
    mutation_name: str
    # Names of the tables whose contents changed (producers sort this list).
    tables_affected: list[str]
    # Total rows inserted; remap_ids reuses this field for rows *updated*.
    rows_added: int
    # Always True in this module: producers raise on failure instead of
    # returning success=False.
    success: bool
|
|
|
|
@dataclass
class TableSchema:
    """Schema information for a single table."""


    # Table name as listed in sqlite_master.
    name: str
    # All column names, in PRAGMA table_info order.
    columns: list[str]
    # Primary-key column names, ordered by their position within the key.
    pk_columns: list[str]
    # Foreign keys as (local_column, referenced_table, referenced_column);
    # the loader drops entries whose metadata is incomplete.
    fk_columns: list[tuple[str, str, str]]
|
|
|
|
def get_table_schemas(db_path: str) -> list[TableSchema]:
    """Extract table schema metadata (columns, PKs, and FKs) from a SQLite DB.

    Args:
        db_path: Filesystem path to an existing SQLite database file.

    Returns:
        One TableSchema per user table (``sqlite_*`` internals excluded),
        ordered by table name.

    Raises:
        sqlite3.OperationalError: If the file does not exist, or if SQLite
            reports any database error while reading it.
    """
    path = Path(db_path)
    if not path.exists():
        raise sqlite3.OperationalError(f"Database does not exist: {db_path}")

    try:
        # "with sqlite3.connect(...)" only manages the transaction and never
        # closes the handle; close it explicitly to avoid leaking it.
        connection = sqlite3.connect(path)
        try:
            cursor = connection.cursor()
            cursor.execute(
                """
                SELECT name
                FROM sqlite_master
                WHERE type = 'table' AND name NOT LIKE 'sqlite_%'
                ORDER BY name
                """
            )
            table_names = [row[0] for row in cursor.fetchall()]

            schemas: list[TableSchema] = []
            for table_name in table_names:
                # PRAGMA statements cannot take bound parameters, so the
                # identifier is escaped by doubling embedded double quotes.
                pragma_name = table_name.replace('"', '""')

                cursor.execute(f'PRAGMA table_info("{pragma_name}")')
                table_info = cursor.fetchall()
                columns = [row[1] for row in table_info]

                # table_info row[5] is the 1-based position of the column in
                # the primary key (0 when the column is not part of the PK).
                pk_ordered = sorted(
                    ((int(row[5]), str(row[1])) for row in table_info if row[5]),
                    key=lambda item: item[0],
                )
                pk_columns = [column_name for _, column_name in pk_ordered]

                # foreign_key_list rows: row[2] = referenced table,
                # row[3] = local column, row[4] = referenced column.  row[4]
                # may be None for implicit-PK references; the filter below
                # drops any FK with incomplete metadata.
                cursor.execute(f'PRAGMA foreign_key_list("{pragma_name}")')
                fk_info = cursor.fetchall()
                fk_columns = [
                    (str(row[3]), str(row[2]), str(row[4]))
                    for row in fk_info
                    if row[3] and row[2] and row[4]
                ]

                schemas.append(
                    TableSchema(
                        name=table_name,
                        columns=columns,
                        pk_columns=pk_columns,
                        fk_columns=fk_columns,
                    )
                )

            return schemas
        finally:
            connection.close()
    except sqlite3.DatabaseError as exc:
        raise sqlite3.OperationalError(str(exc)) from exc
|
|
|
|
def detect_bridge_tables(schemas: list[TableSchema]) -> list[str]:
    """Return tables that look like bridge tables (2 or more foreign keys)."""
    bridge_names: list[str] = []
    for schema in schemas:
        # A table referencing at least two parents is treated as a bridge.
        if len(schema.fk_columns) >= 2:
            bridge_names.append(schema.name)
    return bridge_names
|
|
|
|
| def _quote_identifier(identifier: str) -> str: |
| return f'"{identifier.replace(chr(34), chr(34) + chr(34))}"' |
|
|
|
|
| def _column_affinity(declared_type: str) -> str: |
| normalized = declared_type.upper() |
| if "INT" in normalized: |
| return "INTEGER" |
| if any(token in normalized for token in ("CHAR", "CLOB", "TEXT")): |
| return "TEXT" |
| if any(token in normalized for token in ("REAL", "FLOA", "DOUB")): |
| return "REAL" |
| if "BLOB" in normalized: |
| return "BLOB" |
| return "NUMERIC" |
|
|
|
|
def inject_irrelevant_rows(
    db_path: str,
    schemas: list[TableSchema],
    n_rows: int = 5,
) -> MutationResult:
    """Inject synthetic rows into non-bridge tables with integer primary keys.

    For each eligible table (not a bridge table, exactly one primary-key
    column with INTEGER affinity), inserts ``n_rows`` filler rows:

    * the primary key continues from MAX(pk) + 1;
    * foreign-key columns copy a value from an existing row of the
      referenced table so the insert satisfies FK constraints;
    * remaining columns get affinity-appropriate synthetic values or NULL.

    A table is skipped entirely when a NOT NULL column without a default
    would have to receive NULL.

    Args:
        db_path: Path to the SQLite database to mutate.
        schemas: Schema metadata, e.g. from ``get_table_schemas``.
        n_rows: Rows to insert per eligible table; values <= 0 are a no-op.

    Returns:
        MutationResult with the sorted affected table names and row count.
    """
    if n_rows <= 0:
        return MutationResult(
            mutation_name="inject_irrelevant_rows",
            tables_affected=[],
            rows_added=0,
            success=True,
        )

    bridge_tables = set(detect_bridge_tables(schemas))
    rows_added = 0
    tables_affected: list[str] = []

    # "with sqlite3.connect(...)" only manages the transaction and never
    # closes the handle; manage the connection explicitly to avoid a leak.
    connection = sqlite3.connect(db_path)
    try:
        cursor = connection.cursor()
        for schema in schemas:
            if schema.name in bridge_tables or len(schema.pk_columns) != 1:
                continue

            pk_column = schema.pk_columns[0]
            # PRAGMA cannot take bound parameters; escape by doubling quotes.
            pragma_table = schema.name.replace('"', '""')
            cursor.execute(f'PRAGMA table_info("{pragma_table}")')
            table_info = cursor.fetchall()
            if not table_info:
                continue

            column_by_name = {str(row[1]): row for row in table_info}
            pk_info = column_by_name.get(pk_column)
            if pk_info is None:
                continue
            # Only integer PKs can be extended via MAX(pk) + 1 allocation.
            if _column_affinity(str(pk_info[2])) != "INTEGER":
                continue

            quoted_table = _quote_identifier(schema.name)
            quoted_pk = _quote_identifier(pk_column)
            cursor.execute(f"SELECT MAX({quoted_pk}) FROM {quoted_table}")
            max_pk = cursor.fetchone()[0]
            next_pk = int(max_pk) + 1 if max_pk is not None else 1

            # Borrow one existing value per referenced table so FK columns
            # remain valid; None when the referenced table is empty.
            fk_targets: dict[str, object] = {}
            for fk_column, ref_table, ref_column in schema.fk_columns:
                quoted_ref_table = _quote_identifier(ref_table)
                quoted_ref_column = _quote_identifier(ref_column)
                cursor.execute(
                    f"SELECT {quoted_ref_column} FROM {quoted_ref_table} LIMIT 1"
                )
                result = cursor.fetchone()
                fk_targets[fk_column] = None if result is None else result[0]

            # Record MAX() of each plain integer column so synthetic values
            # land well above existing data.
            integer_column_max: dict[str, int] = {}
            for row in table_info:
                column_name = str(row[1])
                if column_name == pk_column or column_name in fk_targets:
                    continue
                if _column_affinity(str(row[2])) != "INTEGER":
                    continue
                quoted_column = _quote_identifier(column_name)
                cursor.execute(f"SELECT MAX({quoted_column}) FROM {quoted_table}")
                column_max = cursor.fetchone()[0]
                integer_column_max[column_name] = (
                    int(column_max) if column_max is not None else 0
                )

            # The INSERT statement is identical for every synthetic row;
            # build it once instead of rebuilding it per row.
            quoted_columns = ", ".join(
                _quote_identifier(str(row[1])) for row in table_info
            )
            placeholders = ", ".join("?" for _ in table_info)
            insert_sql = (
                f"INSERT INTO {quoted_table}"
                f" ({quoted_columns})"
                f" VALUES ({placeholders})"
            )

            inserted_for_table = 0
            for row_index in range(n_rows):
                row_values: list[object] = []
                skip_table = False
                for row in table_info:
                    column_name = str(row[1])
                    declared_type = str(row[2])
                    not_null = bool(row[3])
                    default_value = row[4]

                    if column_name == pk_column:
                        value: object = next_pk
                        next_pk += 1
                    elif column_name in fk_targets:
                        value = fk_targets[column_name]
                    else:
                        affinity = _column_affinity(declared_type)
                        if affinity == "INTEGER":
                            value = (
                                integer_column_max.get(column_name, 0)
                                + 1000
                                + row_index
                            )
                        elif affinity == "REAL":
                            value = float(row_index + 1)
                        elif affinity in ("TEXT", "NUMERIC"):
                            value = f"SYNTHETIC_{schema.name}_{column_name}_{row_index}"
                        else:
                            value = None

                    if value is None and not_null:
                        if default_value is not None:
                            value = default_value
                        else:
                            # Cannot satisfy NOT NULL; abandon this table.
                            skip_table = True
                            break

                    row_values.append(value)

                if skip_table:
                    inserted_for_table = 0
                    break

                cursor.execute(insert_sql, row_values)
                inserted_for_table += 1

            if inserted_for_table > 0:
                tables_affected.append(schema.name)
                rows_added += inserted_for_table

        connection.commit()
    finally:
        connection.close()

    return MutationResult(
        mutation_name="inject_irrelevant_rows",
        tables_affected=sorted(tables_affected),
        rows_added=rows_added,
        success=True,
    )
|
|
|
|
def remap_ids(db_path: str, schemas: list[TableSchema]) -> MutationResult:
    """Remap integer primary keys and matching foreign keys with a bijection.

    Every table with a single INTEGER-affinity primary key has its key
    values shifted into a fresh range starting at MAX(pk) + 1000 (order
    preserved).  Foreign-key columns referencing a remapped primary key are
    updated with the same mapping, and ``PRAGMA foreign_key_check`` runs
    before committing to verify referential integrity.

    Args:
        db_path: Path to the SQLite database to mutate.
        schemas: Schema metadata, e.g. from ``get_table_schemas``.

    Returns:
        MutationResult; ``rows_added`` carries the count of remapped
        primary-key rows.

    Raises:
        sqlite3.IntegrityError: If the post-remap foreign-key check finds
            violations (all changes are rolled back first).
    """
    remap_plan: dict[str, tuple[str, dict[int, int]]] = {}
    tables_affected: set[str] = set()
    rows_updated = 0

    # "with sqlite3.connect(...)" never closes the handle; manage the
    # connection explicitly so it cannot leak.
    connection = sqlite3.connect(db_path)
    try:
        cursor = connection.cursor()

        # Phase 1: build an old-id -> new-id bijection per eligible table.
        for schema in schemas:
            if len(schema.pk_columns) != 1:
                continue

            pk_column = schema.pk_columns[0]
            quoted_table = _quote_identifier(schema.name)
            quoted_pk = _quote_identifier(pk_column)

            cursor.execute(f"PRAGMA table_info({quoted_table})")
            table_info = cursor.fetchall()
            column_by_name = {str(row[1]): row for row in table_info}
            pk_info = column_by_name.get(pk_column)
            if pk_info is None:
                continue

            if _column_affinity(str(pk_info[2])) != "INTEGER":
                continue

            cursor.execute(
                f"SELECT {quoted_pk} FROM {quoted_table}"
                f" WHERE {quoted_pk} IS NOT NULL"
                f" ORDER BY {quoted_pk}"
            )
            source_ids = [int(row[0]) for row in cursor.fetchall()]
            if not source_ids:
                continue

            # New ids start above MAX(id) + 1000, so they cannot collide
            # with any existing id during the in-place UPDATE.
            start_id = max(source_ids) + 1000
            mapping = {
                source_id: start_id + index
                for index, source_id in enumerate(source_ids)
            }
            remap_plan[schema.name] = (pk_column, mapping)

        if not remap_plan:
            return MutationResult(
                mutation_name="remap_ids",
                tables_affected=[],
                rows_added=0,
                success=True,
            )

        try:
            # Must execute before any UPDATE opens a transaction: SQLite
            # silently ignores this pragma while a transaction is active.
            cursor.execute("PRAGMA foreign_keys = OFF")

            # Phase 2: rewrite the primary keys via a single CASE update.
            for table_name, (pk_column, mapping) in remap_plan.items():
                quoted_table = _quote_identifier(table_name)
                quoted_pk = _quote_identifier(pk_column)

                case_parts = " ".join(
                    f"WHEN {old_id} THEN {new_id}" for old_id, new_id in mapping.items()
                )
                where_values = ", ".join(str(old_id) for old_id in mapping)
                update_sql = (
                    f"UPDATE {quoted_table}"
                    f" SET {quoted_pk} = CASE {quoted_pk}"
                    f" {case_parts} ELSE {quoted_pk} END"
                    f" WHERE {quoted_pk} IN ({where_values})"
                )
                cursor.execute(update_sql)

                tables_affected.add(table_name)
                rows_updated += len(mapping)

            # Phase 3: rewrite child FK columns using the parent's mapping.
            for child_schema in schemas:
                quoted_child_table = _quote_identifier(child_schema.name)
                for fk_column, ref_table, ref_column in child_schema.fk_columns:
                    parent_plan = remap_plan.get(ref_table)
                    if parent_plan is None:
                        continue

                    parent_pk_column, parent_mapping = parent_plan
                    if ref_column != parent_pk_column:
                        continue

                    quoted_fk = _quote_identifier(fk_column)
                    case_parts = " ".join(
                        f"WHEN {old_id} THEN {new_id}"
                        for old_id, new_id in parent_mapping.items()
                    )
                    where_values = ", ".join(str(old_id) for old_id in parent_mapping)
                    update_sql = (
                        f"UPDATE {quoted_child_table}"
                        f" SET {quoted_fk} = CASE {quoted_fk}"
                        f" {case_parts} ELSE {quoted_fk} END"
                        f" WHERE {quoted_fk} IN ({where_values})"
                    )
                    cursor.execute(update_sql)

                    if cursor.rowcount > 0:
                        tables_affected.add(child_schema.name)

            # foreign_key_check works regardless of the enforcement setting,
            # so it can validate the pending transaction before COMMIT.
            cursor.execute("PRAGMA foreign_key_check")
            fk_violations = cursor.fetchall()
            if fk_violations:
                raise sqlite3.IntegrityError(
                    "Foreign key integrity check failed"
                    f" after ID remapping: {fk_violations[0]}"
                )

            connection.commit()
            # Re-enable enforcement only after COMMIT; issuing the pragma
            # inside the open transaction (as before) is a silent no-op.
            cursor.execute("PRAGMA foreign_keys = ON")
        except Exception:
            connection.rollback()
            cursor.execute("PRAGMA foreign_keys = ON")
            raise
    finally:
        connection.close()

    return MutationResult(
        mutation_name="remap_ids",
        tables_affected=sorted(tables_affected),
        rows_added=rows_updated,
        success=True,
    )
|
|
|
|
def duplicate_bridge_rows(
    db_path: str,
    schemas: list[TableSchema],
    bridge_tables: list[str],
) -> MutationResult:
    """Duplicate bridge-table rows, skipping rows blocked by constraints.

    Every existing row of each listed bridge table is re-inserted once.
    Rows that would violate a constraint (PRIMARY KEY/UNIQUE) are skipped
    individually rather than aborting the whole table.

    Args:
        db_path: Path to the SQLite database to mutate.
        schemas: Schema metadata; listed tables missing from it are ignored.
        bridge_tables: Names of the tables whose rows should be duplicated.

    Returns:
        MutationResult with the sorted affected table names and row count.
    """
    if not bridge_tables:
        return MutationResult(
            mutation_name="duplicate_bridge_rows",
            tables_affected=[],
            rows_added=0,
            success=True,
        )

    schema_names = {schema.name for schema in schemas}
    rows_added = 0
    tables_affected: list[str] = []

    # "with sqlite3.connect(...)" only manages the transaction and never
    # closes the handle; manage the connection explicitly to avoid a leak.
    connection = sqlite3.connect(db_path)
    try:
        cursor = connection.cursor()

        for table_name in bridge_tables:
            if table_name not in schema_names:
                continue

            quoted_table = _quote_identifier(table_name)
            cursor.execute(f"PRAGMA table_info({quoted_table})")
            table_info = cursor.fetchall()
            if not table_info:
                continue

            column_names = [str(row[1]) for row in table_info]
            quoted_columns = ", ".join(_quote_identifier(name) for name in column_names)
            placeholders = ", ".join("?" for _ in column_names)
            # The INSERT statement is identical for every row; build it once
            # instead of rebuilding it inside the loop.
            insert_sql = (
                f"INSERT INTO {quoted_table}"
                f" ({quoted_columns})"
                f" VALUES ({placeholders})"
            )

            cursor.execute(f"SELECT {quoted_columns} FROM {quoted_table}")
            existing_rows = cursor.fetchall()
            inserted_for_table = 0

            for row in existing_rows:
                try:
                    cursor.execute(insert_sql, row)
                    inserted_for_table += 1
                except sqlite3.IntegrityError:
                    # Deliberate best-effort: constrained duplicates skipped.
                    continue

            if inserted_for_table > 0:
                tables_affected.append(table_name)
                rows_added += inserted_for_table

        connection.commit()
    finally:
        connection.close()

    return MutationResult(
        mutation_name="duplicate_bridge_rows",
        tables_affected=sorted(tables_affected),
        rows_added=rows_added,
        success=True,
    )
|
|