sql_env / server /synthetic /mutations.py
hjerpe's picture
Upload folder using huggingface_hub
9e64e71 verified
"""Schema introspection and mutation helpers for synthetic database generation."""
from __future__ import annotations
import sqlite3
from dataclasses import dataclass
from pathlib import Path
@dataclass
class MutationResult:
    """Outcome record returned by each mutation helper in this module.

    Attributes:
        mutation_name: Name of the mutation that produced the record.
        tables_affected: Names of the tables the mutation changed.
        rows_added: Row count reported by the mutation. The inserting helpers
            report rows inserted; ``remap_ids`` reports the number of
            primary-key values it rewrote.
        success: Completion flag. Every helper visible in this module sets it
            to ``True`` and signals failure by raising instead.
    """

    mutation_name: str
    tables_affected: list[str]
    rows_added: int
    success: bool
@dataclass
class TableSchema:
    """Structural description of one database table.

    Attributes:
        name: Table name.
        columns: Column names of the table.
        pk_columns: Names of the primary-key columns.
        fk_columns: One ``(local column, referenced table, referenced column)``
            tuple per foreign-key column.
    """

    name: str
    columns: list[str]
    pk_columns: list[str]
    fk_columns: list[tuple[str, str, str]]
def get_table_schemas(db_path: str) -> list[TableSchema]:
    """Extract table schema metadata (columns, PKs, and FKs) from a SQLite DB.

    Args:
        db_path: Filesystem path of the SQLite database to inspect.

    Returns:
        One ``TableSchema`` per table listed in ``sqlite_master`` whose name
        does not start with ``sqlite_``, ordered by table name. Primary-key
        columns are listed in key order. Each foreign key is reported as
        ``(local column, referenced table, referenced column)``; when the
        constraint names no parent column, the parent's primary-key column at
        the same position is used.

    Raises:
        sqlite3.OperationalError: If ``db_path`` does not exist, or if SQLite
            raises any ``sqlite3.DatabaseError`` while the schema is read.
    """
    path = Path(db_path)
    if not path.exists():
        raise sqlite3.OperationalError(f"Database does not exist: {db_path}")

    try:
        connection = sqlite3.connect(path)
        try:
            cursor = connection.cursor()
            cursor.execute(
                """
                SELECT name
                FROM sqlite_master
                WHERE type = 'table' AND name NOT LIKE 'sqlite_%'
                ORDER BY name
                """
            )
            table_names = [row[0] for row in cursor.fetchall()]

            # First pass: columns and primary keys. Every table's key must be
            # known before foreign keys with implicit targets can be resolved.
            columns_by_table: dict[str, list[str]] = {}
            pk_by_table: dict[str, list[str]] = {}
            for table_name in table_names:
                quoted_table = _quote_identifier(table_name)
                cursor.execute(f"PRAGMA table_info({quoted_table})")
                table_info = cursor.fetchall()
                columns_by_table[table_name] = [row[1] for row in table_info]
                # table_info column 5 is the 1-based position within the
                # primary key (0 for non-key columns).
                pk_ordered = sorted(
                    ((int(row[5]), str(row[1])) for row in table_info if row[5]),
                    key=lambda item: item[0],
                )
                pk_by_table[table_name] = [name for _, name in pk_ordered]

            # SQLite compares identifiers case-insensitively, and a constraint
            # may spell the parent table differently from its definition.
            pk_by_folded_name = {
                str(name).lower(): pk_columns
                for name, pk_columns in pk_by_table.items()
            }

            # Second pass: foreign keys.
            schemas: list[TableSchema] = []
            for table_name in table_names:
                quoted_table = _quote_identifier(table_name)
                cursor.execute(f"PRAGMA foreign_key_list({quoted_table})")
                fk_columns: list[tuple[str, str, str]] = []
                for row in cursor.fetchall():
                    seq, ref_table, fk_column, ref_column = row[1:5]
                    if not (fk_column and ref_table):
                        continue
                    if not ref_column:
                        # "REFERENCES parent" without a column list targets
                        # the parent's primary key; ``seq`` is the position of
                        # this column within the constraint.
                        parent_pk = pk_by_folded_name.get(
                            str(ref_table).lower(), []
                        )
                        if int(seq) >= len(parent_pk):
                            continue
                        ref_column = parent_pk[int(seq)]
                    fk_columns.append(
                        (str(fk_column), str(ref_table), str(ref_column))
                    )
                schemas.append(
                    TableSchema(
                        name=table_name,
                        columns=columns_by_table[table_name],
                        pk_columns=pk_by_table[table_name],
                        fk_columns=fk_columns,
                    )
                )
            return schemas
        finally:
            # Using the connection as a context manager would only end the
            # transaction; the handle has to be closed explicitly.
            connection.close()
    except sqlite3.DatabaseError as exc:
        raise sqlite3.OperationalError(str(exc)) from exc
def detect_bridge_tables(schemas: list[TableSchema]) -> list[str]:
    """List the names of tables that carry at least two foreign-key columns.

    Args:
        schemas: Table descriptions to scan.

    Returns:
        Names of the qualifying tables, in the order they appear in
        ``schemas``.
    """
    bridge_names: list[str] = []
    for table in schemas:
        if len(table.fk_columns) < 2:
            continue
        bridge_names.append(table.name)
    return bridge_names
def _quote_identifier(identifier: str) -> str:
return f'"{identifier.replace(chr(34), chr(34) + chr(34))}"'
def _column_affinity(declared_type: str) -> str:
normalized = declared_type.upper()
if "INT" in normalized:
return "INTEGER"
if any(token in normalized for token in ("CHAR", "CLOB", "TEXT")):
return "TEXT"
if any(token in normalized for token in ("REAL", "FLOA", "DOUB")):
return "REAL"
if "BLOB" in normalized:
return "BLOB"
return "NUMERIC"
def inject_irrelevant_rows(
    db_path: str,
    schemas: list[TableSchema],
    n_rows: int = 5,
) -> MutationResult:
    """Inject synthetic rows into non-bridge tables with integer primary keys.

    A table is eligible when ``detect_bridge_tables`` does not list it, it has
    exactly one primary-key column, and that column's declared type maps to
    INTEGER via ``_column_affinity``. All inserts run in one transaction.

    Args:
        db_path: Path of the SQLite database to modify.
        schemas: Table descriptions, as produced by ``get_table_schemas``.
        n_rows: Rows to add to each eligible table. Values ``<= 0`` make the
            call a no-op.

    Returns:
        A ``MutationResult`` listing the tables that received rows (sorted)
        and the total number of rows inserted.

    Raises:
        sqlite3.Error: Any database error propagates after the transaction
            has been rolled back.
    """
    if n_rows <= 0:
        return MutationResult(
            mutation_name="inject_irrelevant_rows",
            tables_affected=[],
            rows_added=0,
            success=True,
        )

    bridge_tables = set(detect_bridge_tables(schemas))
    rows_added = 0
    tables_affected: list[str] = []

    connection = sqlite3.connect(db_path)
    try:
        # ``with connection`` commits on success and rolls back on error; it
        # does not close the handle, hence the surrounding try/finally.
        with connection:
            cursor = connection.cursor()
            for schema in schemas:
                if schema.name in bridge_tables or len(schema.pk_columns) != 1:
                    continue
                inserted = _inject_rows_into_table(cursor, schema, n_rows)
                if inserted > 0:
                    tables_affected.append(schema.name)
                    rows_added += inserted
    finally:
        connection.close()

    return MutationResult(
        mutation_name="inject_irrelevant_rows",
        tables_affected=sorted(tables_affected),
        rows_added=rows_added,
        success=True,
    )


def _inject_rows_into_table(
    cursor: sqlite3.Cursor,
    schema: TableSchema,
    n_rows: int,
) -> int:
    """Insert ``n_rows`` synthetic rows into one table; return the count added.

    Returns 0 without inserting when the primary key is missing from the
    table, is not INTEGER-affinity, or when a NOT NULL column without a
    default would have to receive NULL.

    Generated values per column:
        * primary key: continues from ``MAX(pk) + 1`` (``1`` for an empty table);
        * foreign key: the first value returned from the referenced column
          (NULL when the referenced table is empty);
        * INTEGER affinity: ``MAX(column) + 1000 + row index``;
        * REAL affinity: ``float(row index + 1)``;
        * TEXT / NUMERIC affinity: ``SYNTHETIC_<table>_<column>_<row index>``;
        * anything else (BLOB): NULL.
    """
    # Value-generation strategies used by the per-column plan below.
    counter, constant, real, text = "counter", "constant", "real", "text"

    pk_column = schema.pk_columns[0]
    quoted_table = _quote_identifier(schema.name)
    cursor.execute(f"PRAGMA table_info({quoted_table})")
    table_info = cursor.fetchall()

    column_by_name = {str(row[1]): row for row in table_info}
    pk_info = column_by_name.get(pk_column)
    if pk_info is None or _column_affinity(str(pk_info[2])) != "INTEGER":
        return 0

    quoted_pk = _quote_identifier(pk_column)
    cursor.execute(f"SELECT MAX({quoted_pk}) FROM {quoted_table}")
    max_pk = cursor.fetchone()[0]
    first_pk = int(max_pk) + 1 if max_pk is not None else 1

    # Reuse one existing parent value per foreign key so that the synthetic
    # rows reference a real parent row.
    fk_targets: dict[str, object] = {}
    for fk_column, ref_table, ref_column in schema.fk_columns:
        quoted_ref_table = _quote_identifier(ref_table)
        quoted_ref_column = _quote_identifier(ref_column)
        cursor.execute(
            f"SELECT {quoted_ref_column} FROM {quoted_ref_table} LIMIT 1"
        )
        parent_row = cursor.fetchone()
        fk_targets[fk_column] = None if parent_row is None else parent_row[0]

    # Decide once per column how its value is produced. Whether a column
    # yields NULL does not depend on the row index, so the skip/omit decision
    # can be taken before anything is inserted.
    plan: list[tuple[str, str, object]] = []
    for row in table_info:
        column_name = str(row[1])
        not_null = bool(row[3])
        has_default = row[4] is not None

        if column_name == pk_column:
            plan.append((column_name, counter, first_pk))
            continue

        payload: object
        if column_name in fk_targets:
            strategy, payload = constant, fk_targets[column_name]
        else:
            affinity = _column_affinity(str(row[2]))
            if affinity == "INTEGER":
                quoted_column = _quote_identifier(column_name)
                cursor.execute(
                    f"SELECT MAX({quoted_column}) FROM {quoted_table}"
                )
                column_max = cursor.fetchone()[0]
                base = int(column_max) if column_max is not None else 0
                # Offset by 1000 to stay clear of the existing values.
                strategy, payload = counter, base + 1000
            elif affinity == "REAL":
                strategy, payload = real, None
            elif affinity in ("TEXT", "NUMERIC"):
                strategy, payload = text, f"SYNTHETIC_{schema.name}_{column_name}_"
            else:
                strategy, payload = constant, None

        if strategy == constant and payload is None and not_null:
            if not has_default:
                # No legal value exists for this column: skip the table.
                return 0
            # Leave the column out of the INSERT so SQLite evaluates the
            # declared DEFAULT. table_info reports the default as SQL source
            # text (quotes included), which must not be bound as a value.
            continue

        plan.append((column_name, strategy, payload))

    quoted_columns = ", ".join(_quote_identifier(name) for name, _, _ in plan)
    placeholders = ", ".join("?" for _ in plan)
    insert_sql = (
        f"INSERT INTO {quoted_table}"
        f" ({quoted_columns})"
        f" VALUES ({placeholders})"
    )

    synthetic_rows: list[list[object]] = []
    for row_index in range(n_rows):
        values: list[object] = []
        for _, strategy, payload in plan:
            if strategy == counter:
                values.append(payload + row_index)
            elif strategy == real:
                values.append(float(row_index + 1))
            elif strategy == text:
                values.append(f"{payload}{row_index}")
            else:
                values.append(payload)
        synthetic_rows.append(values)

    cursor.executemany(insert_sql, synthetic_rows)
    return n_rows
def remap_ids(db_path: str, schemas: list[TableSchema]) -> MutationResult:
    """Remap integer primary keys and matching foreign keys with a bijection.

    Every table with exactly one primary-key column of INTEGER affinity gets
    its non-NULL key values, in ascending order, mapped to the consecutive
    range starting at ``max(existing id) + 1000``. Foreign-key columns that
    reference such a primary key are rewritten with the parent's mapping.
    Parent table and column names are matched case-insensitively, as SQLite
    does. The whole remap runs in one transaction and is rolled back if
    ``PRAGMA foreign_key_check`` reports a violation afterwards.

    NOTE(review): when a table's primary-key column is itself a foreign key
    to another remapped table, that column is rewritten by both passes with
    unrelated mappings, and the final integrity check is then likely to fail.
    Confirm whether such schemas need to be supported.

    Args:
        db_path: Path of the SQLite database to modify.
        schemas: Table descriptions, as produced by ``get_table_schemas``.

    Returns:
        A ``MutationResult`` whose ``tables_affected`` lists every table that
        had a key or foreign-key value rewritten (sorted) and whose
        ``rows_added`` is the number of primary-key values remapped.

    Raises:
        sqlite3.IntegrityError: If the foreign-key check fails after the remap.
        sqlite3.Error: Any other database error, after a rollback.
    """
    connection = sqlite3.connect(db_path)
    try:
        cursor = connection.cursor()
        remap_plan = _plan_id_remapping(cursor, schemas)
        if not remap_plan:
            return MutationResult(
                mutation_name="remap_ids",
                tables_affected=[],
                rows_added=0,
                success=True,
            )
        tables_affected, rows_updated = _apply_id_remapping(
            connection, cursor, schemas, remap_plan
        )
    finally:
        # The sqlite3 connection context manager never closes the handle, so
        # close it explicitly. Foreign-key enforcement is a per-connection
        # setting, so there is nothing to restore once the handle is gone.
        connection.close()

    return MutationResult(
        mutation_name="remap_ids",
        tables_affected=sorted(tables_affected),
        rows_added=rows_updated,
        success=True,
    )


def _plan_id_remapping(
    cursor: sqlite3.Cursor,
    schemas: list[TableSchema],
) -> dict[str, tuple[str, str, dict[int, int]]]:
    """Build ``{lower-cased table: (table, pk column, old id -> new id)}``."""
    plan: dict[str, tuple[str, str, dict[int, int]]] = {}
    for schema in schemas:
        if len(schema.pk_columns) != 1:
            continue
        pk_column = schema.pk_columns[0]
        quoted_table = _quote_identifier(schema.name)
        quoted_pk = _quote_identifier(pk_column)

        cursor.execute(f"PRAGMA table_info({quoted_table})")
        declared_types = {str(row[1]): str(row[2]) for row in cursor.fetchall()}
        declared_type = declared_types.get(pk_column)
        if declared_type is None or _column_affinity(declared_type) != "INTEGER":
            continue

        cursor.execute(
            f"SELECT {quoted_pk} FROM {quoted_table}"
            f" WHERE {quoted_pk} IS NOT NULL"
            f" ORDER BY {quoted_pk}"
        )
        source_ids = [int(row[0]) for row in cursor.fetchall()]
        if not source_ids:
            continue

        # New ids start beyond every existing id, so a row that has already
        # been rewritten can never collide with one that has not.
        start_id = max(source_ids) + 1000
        mapping = {
            source_id: start_id + index
            for index, source_id in enumerate(source_ids)
        }
        plan[schema.name.lower()] = (schema.name, pk_column, mapping)
    return plan


def _apply_id_remapping(
    connection: sqlite3.Connection,
    cursor: sqlite3.Cursor,
    schemas: list[TableSchema],
    remap_plan: dict[str, tuple[str, str, dict[int, int]]],
) -> tuple[set[str], int]:
    """Rewrite keys per ``remap_plan``; return (tables touched, ids remapped)."""
    map_table = 'temp."_remap_ids_map"'
    tables_affected: set[str] = set()
    rows_updated = 0

    def remap_column(table_name: str, column_name: str, plan_key: str) -> int:
        """Rewrite one column through the mapping stored under ``plan_key``."""
        quoted_table = _quote_identifier(table_name)
        quoted_column = _quote_identifier(column_name)
        # The outer column is referenced fully qualified inside the subquery;
        # the mapping table's own columns resolve from the inner scope.
        cursor.execute(
            f"UPDATE {quoted_table}"
            f" SET {quoted_column} = ("
            f"SELECT new_id FROM {map_table}"
            f" WHERE table_key = ? AND old_id = {quoted_table}.{quoted_column})"
            f" WHERE {quoted_column} IN ("
            f"SELECT old_id FROM {map_table} WHERE table_key = ?)",
            (plan_key, plan_key),
        )
        return cursor.rowcount

    # Must run before any write opens a transaction: the pragma is a no-op
    # inside one. Enforcement is suspended because parent keys and child
    # references are rewritten in separate statements.
    cursor.execute("PRAGMA foreign_keys = OFF")
    # A keyed lookup table replaces a CASE expression with one WHEN per row,
    # which costs O(rows^2) and produces very large SQL text on big tables.
    cursor.execute(
        'CREATE TEMP TABLE "_remap_ids_map" ('
        "table_key TEXT NOT NULL,"
        " old_id INTEGER NOT NULL,"
        " new_id INTEGER NOT NULL,"
        " PRIMARY KEY (table_key, old_id))"
    )

    try:
        for plan_key, (_, _, mapping) in remap_plan.items():
            cursor.executemany(
                f"INSERT INTO {map_table} (table_key, old_id, new_id)"
                " VALUES (?, ?, ?)",
                (
                    (plan_key, old_id, new_id)
                    for old_id, new_id in mapping.items()
                ),
            )

        # Pass 1: primary keys.
        for plan_key, (table_name, pk_column, mapping) in remap_plan.items():
            remap_column(table_name, pk_column, plan_key)
            tables_affected.add(table_name)
            rows_updated += len(mapping)

        # Pass 2: foreign keys that point at a remapped primary key.
        for child_schema in schemas:
            for fk_column, ref_table, ref_column in child_schema.fk_columns:
                # A constraint may spell the parent differently from its
                # definition; SQLite treats identifiers case-insensitively.
                plan_key = ref_table.lower()
                parent_plan = remap_plan.get(plan_key)
                if parent_plan is None:
                    continue
                parent_pk_column = parent_plan[1]
                if ref_column.lower() != parent_pk_column.lower():
                    continue
                if remap_column(child_schema.name, fk_column, plan_key) > 0:
                    tables_affected.add(child_schema.name)

        # foreign_key_check works regardless of the enforcement setting.
        cursor.execute("PRAGMA foreign_key_check")
        fk_violations = cursor.fetchall()
        if fk_violations:
            raise sqlite3.IntegrityError(
                "Foreign key integrity check failed"
                f" after ID remapping: {fk_violations[0]}"
            )
        connection.commit()
    except Exception:
        connection.rollback()
        raise

    return tables_affected, rows_updated
def duplicate_bridge_rows(
    db_path: str,
    schemas: list[TableSchema],
    bridge_tables: list[str],
) -> MutationResult:
    """Duplicate bridge-table rows, skipping rows blocked by constraints.

    For each requested table that also appears in ``schemas``, every existing
    row is re-inserted once. A row whose copy raises
    ``sqlite3.IntegrityError`` is skipped and the remaining rows are still
    attempted. All inserts run in one transaction.

    Args:
        db_path: Path of the SQLite database to modify.
        schemas: Known table descriptions; names missing from here are ignored.
        bridge_tables: Names of the tables whose rows should be duplicated.

    Returns:
        A ``MutationResult`` listing the tables that gained rows (sorted) and
        the total number of copies inserted.

    Raises:
        sqlite3.Error: Any database error other than a per-row
            ``IntegrityError`` propagates after the transaction is rolled back.
    """
    if not bridge_tables:
        return MutationResult(
            mutation_name="duplicate_bridge_rows",
            tables_affected=[],
            rows_added=0,
            success=True,
        )

    schema_names = {schema.name for schema in schemas}
    rows_added = 0
    tables_affected: list[str] = []

    connection = sqlite3.connect(db_path)
    try:
        # ``with connection`` commits on success and rolls back on error; it
        # does not close the handle, hence the surrounding try/finally.
        with connection:
            cursor = connection.cursor()
            for table_name in bridge_tables:
                if table_name not in schema_names:
                    continue

                quoted_table = _quote_identifier(table_name)
                cursor.execute(f"PRAGMA table_info({quoted_table})")
                column_names = [str(row[1]) for row in cursor.fetchall()]
                if not column_names:
                    continue

                quoted_columns = ", ".join(
                    _quote_identifier(name) for name in column_names
                )
                placeholders = ", ".join("?" for _ in column_names)
                insert_sql = (
                    f"INSERT INTO {quoted_table}"
                    f" ({quoted_columns})"
                    f" VALUES ({placeholders})"
                )

                # Snapshot the rows first so the copies inserted below are
                # not themselves read back and duplicated again.
                cursor.execute(f"SELECT {quoted_columns} FROM {quoted_table}")
                existing_rows = cursor.fetchall()

                inserted_for_table = 0
                for row in existing_rows:
                    try:
                        cursor.execute(insert_sql, row)
                    except sqlite3.IntegrityError:
                        # A constraint rejects this copy; keep going with the
                        # remaining rows.
                        continue
                    inserted_for_table += 1

                if inserted_for_table > 0:
                    tables_affected.append(table_name)
                    rows_added += inserted_for_table
    finally:
        connection.close()

    return MutationResult(
        mutation_name="duplicate_bridge_rows",
        tables_affected=sorted(tables_affected),
        rows_added=rows_added,
        success=True,
    )