Spaces:
Running
Running
File size: 3,600 Bytes
dc59b01 4f2cd24 dc59b01 4f2cd24 dc59b01 4f2cd24 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 | import sqlite3
class SchemaEncoder:
    """Render the schema of a SQLite database as prompt-ready text.

    Databases are looked up as ``<db_root>/<db_id>.sqlite`` (flat layout:
    one file per database id directly under ``db_root``).
    """

    def __init__(self, db_root):
        # db_root may be a str or any path-like object; it is converted to a
        # pathlib.Path each time a database path is built.
        self.db_root = db_root

    @staticmethod
    def _quote_identifier(name):
        """Return *name* as a double-quoted SQLite identifier."""
        # SQLite escapes a double quote inside a quoted identifier by doubling
        # it. Quoting lets table names with spaces, dashes or keywords through.
        return '"' + name.replace('"', '""') + '"'

    def get_tables_and_columns(self, db_id):
        """Return a dict mapping each table name to its list of column names.

        Tables appear in the order ``sqlite_master`` yields them; columns are
        taken from ``PRAGMA table_info`` in the order it reports them.

        Args:
            db_id: Database identifier; the file read is ``<db_root>/<db_id>.sqlite``.

        Raises:
            FileNotFoundError: if the database file does not exist.
        """
        from pathlib import Path

        db_path = Path(self.db_root) / f"{db_id}.sqlite"
        # sqlite3.connect() would silently create an empty database for a
        # missing path, turning a typo into an empty schema. Fail loudly instead.
        if not db_path.is_file():
            raise FileNotFoundError(f"SQLite database not found: {db_path}")

        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            tables = cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table';"
            ).fetchall()
            schema = {}
            for (table,) in tables:
                cols = cursor.execute(
                    f"PRAGMA table_info({self._quote_identifier(table)});"
                ).fetchall()
                # Each table_info row is (cid, name, type, notnull, dflt_value, pk).
                schema[table] = [c[1] for c in cols]
            return schema
        finally:
            # Close the connection even if one of the queries raised.
            conn.close()

    # -----------------------------------
    # Strategy 1: Structured
    # -----------------------------------
    def structured_schema(self, db_id):
        """Describe the schema as one ``table(col1, col2, ...)`` line per table."""
        schema = self.get_tables_and_columns(db_id)
        return "\n".join(
            f"{table}({', '.join(cols)})" for table, cols in schema.items()
        )

    # -----------------------------------
    # Strategy 2: Natural Language
    # -----------------------------------
    def natural_language_schema(self, db_id):
        """Describe the schema as one English sentence per table."""
        schema = self.get_tables_and_columns(db_id)
        return "\n".join(
            f"The table '{table}' contains the columns: {', '.join(cols)}."
            for table, cols in schema.items()
        )
# import sqlite3
# import re
# def build_schema_graph(schema_text):
# """
# Parses a structured schema text string into a dictionary graph.
# Matches formats like: table_name(col1, col2, col3)
# """
# tables = {}
# for match in re.findall(r'(\w+)\s*\((.*?)\)', schema_text):
# table = match[0]
# # Extracts just the column names, ignoring potential types or constraints
# cols = [c.strip().split()[0] for c in match[1].split(",")]
# tables[table] = cols
# return tables
# class SchemaEncoder:
# def __init__(self, db_root):
# self.db_root = db_root
# def get_tables_and_columns(self, db_id):
# # Assuming db_root is a pathlib.Path object based on the syntax
# db_path = self.db_root / db_id / f"{db_id}.sqlite"
# conn = sqlite3.connect(db_path)
# cursor = conn.cursor()
# tables = cursor.execute(
# "SELECT name FROM sqlite_master WHERE type='table';"
# ).fetchall()
# schema = {}
# for (table,) in tables:
# cols = cursor.execute(f"PRAGMA table_info({table});").fetchall()
# col_names = [c[1] for c in cols]
# schema[table] = col_names
# conn.close()
# return schema
# # -----------------------------------
# # Strategy 1: Structured (current)
# # -----------------------------------
# def structured_schema(self, db_id):
# schema = self.get_tables_and_columns(db_id)
# lines = []
# for table, cols in schema.items():
# lines.append(f"{table}({', '.join(cols)})")
# return "\n".join(lines)
# # -----------------------------------
# # Strategy 2: Natural Language
# # -----------------------------------
# def natural_language_schema(self, db_id):
# schema = self.get_tables_and_columns(db_id)
# lines = []
# for table, cols in schema.items():
# col_text = ", ".join(cols)
# lines.append(f"The table '{table}' contains the columns: {col_text}.")
# return "\n".join(lines) |