File size: 3,600 Bytes
dc59b01
 
 
 
 
 
 
 
 
4f2cd24
 
 
 
dc59b01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f2cd24
dc59b01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f2cd24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import sqlite3
from pathlib import Path


class SchemaEncoder:
    """Render the schema of SQLite database files as plain text.

    Each database is looked up at ``<db_root>/<db_id>.sqlite``.
    """

    def __init__(self, db_root):
        """Remember the directory that holds the ``.sqlite`` files.

        Args:
            db_root: Directory containing one ``<db_id>.sqlite`` file per
                database. May be a ``str`` or any ``os.PathLike``.
        """
        self.db_root = db_root

    @staticmethod
    def _quote_identifier(name):
        """Return *name* as a double-quoted SQL identifier.

        Embedded double quotes are doubled, so names that are reserved
        words or contain spaces/punctuation stay valid inside a statement.
        """
        return '"' + name.replace('"', '""') + '"'

    def get_tables_and_columns(self, db_id):
        """Read the table and column names of one database.

        Args:
            db_id: Database identifier; the file opened is
                ``<db_root>/<db_id>.sqlite``.

        Returns:
            dict mapping each table name listed in ``sqlite_master`` to the
            list of its column names, in the order ``PRAGMA table_info``
            reports them.

        Raises:
            sqlite3.Error: If the file cannot be opened or queried.

        Note:
            ``sqlite3.connect`` creates the file when it is missing (and the
            parent directory exists), in which case the result is ``{}``.
        """
        # Path(...) lets db_root be a plain string as well as a Path.
        db_path = Path(self.db_root) / f"{db_id}.sqlite"

        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()

            tables = cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table';"
            ).fetchall()

            schema = {}
            for (table,) in tables:
                # PRAGMA arguments cannot be bound as parameters, so the
                # table name is quoted as an identifier instead.
                quoted = self._quote_identifier(table)
                cols = cursor.execute(
                    f"PRAGMA table_info({quoted});"
                ).fetchall()
                # Each table_info row is (cid, name, type, ...); keep name.
                schema[table] = [col[1] for col in cols]

            return schema
        finally:
            # Release the connection even when a query above raises.
            conn.close()

    # -----------------------------------
    # Strategy 1: Structured
    # -----------------------------------
    def structured_schema(self, db_id):
        """Return the schema as one ``table(col1, col2, ...)`` line per table.

        Args:
            db_id: Database identifier, as for ``get_tables_and_columns``.

        Returns:
            str with the lines joined by ``"\\n"`` (empty if no tables).
        """
        schema = self.get_tables_and_columns(db_id)
        return "\n".join(
            f"{table}({', '.join(cols)})" for table, cols in schema.items()
        )

    # -----------------------------------
    # Strategy 2: Natural Language
    # -----------------------------------
    def natural_language_schema(self, db_id):
        """Return the schema as one English sentence per table.

        Args:
            db_id: Database identifier, as for ``get_tables_and_columns``.

        Returns:
            str with the sentences joined by ``"\\n"`` (empty if no tables).
        """
        schema = self.get_tables_and_columns(db_id)
        return "\n".join(
            f"The table '{table}' contains the columns: {', '.join(cols)}."
            for table, cols in schema.items()
        )


# import sqlite3
# import re

# def build_schema_graph(schema_text):
#     """
#     Parses a structured schema text string into a dictionary graph.
#     Matches formats like: table_name(col1, col2, col3)
#     """
#     tables = {}
    
#     for match in re.findall(r'(\w+)\s*\((.*?)\)', schema_text):
#         table = match[0]
#         # Extracts just the column names, ignoring potential types or constraints
#         cols = [c.strip().split()[0] for c in match[1].split(",")]
#         tables[table] = cols
    
#     return tables


# class SchemaEncoder:

#     def __init__(self, db_root):
#         self.db_root = db_root

#     def get_tables_and_columns(self, db_id):
#         # Assuming db_root is a pathlib.Path object based on the syntax
#         db_path = self.db_root / db_id / f"{db_id}.sqlite"
#         conn = sqlite3.connect(db_path)
#         cursor = conn.cursor()

#         tables = cursor.execute(
#             "SELECT name FROM sqlite_master WHERE type='table';"
#         ).fetchall()

#         schema = {}

#         for (table,) in tables:
#             cols = cursor.execute(f"PRAGMA table_info({table});").fetchall()
#             col_names = [c[1] for c in cols]
#             schema[table] = col_names

#         conn.close()
#         return schema

#     # -----------------------------------
#     # Strategy 1: Structured (current)
#     # -----------------------------------
#     def structured_schema(self, db_id):
#         schema = self.get_tables_and_columns(db_id)

#         lines = []
#         for table, cols in schema.items():
#             lines.append(f"{table}({', '.join(cols)})")

#         return "\n".join(lines)

#     # -----------------------------------
#     # Strategy 2: Natural Language
#     # -----------------------------------
#     def natural_language_schema(self, db_id):
#         schema = self.get_tables_and_columns(db_id)

#         lines = []
#         for table, cols in schema.items():
#             col_text = ", ".join(cols)
#             lines.append(f"The table '{table}' contains the columns: {col_text}.")

#         return "\n".join(lines)