Rifqi Hafizuddin committed on
Commit
4353929
·
1 Parent(s): 65a5c6b

[KM-437][DB] add mysql, sqlserver, bigquery, snowflake connections

Browse files
src/api/v1/db_client.py CHANGED
@@ -340,12 +340,7 @@ async def ingest_database_client(
340
  try:
341
  with db_pipeline_service.engine_scope(
342
  db_type=client.db_type,
343
- host=creds["host"],
344
- port=creds["port"],
345
- database=creds["database"],
346
- username=creds["username"],
347
- password=creds["password"],
348
- ssl_mode=creds.get("ssl_mode"),
349
  ) as engine:
350
  total = await db_pipeline_service.run(user_id=user_id, engine=engine)
351
  except NotImplementedError as e:
 
340
  try:
341
  with db_pipeline_service.engine_scope(
342
  db_type=client.db_type,
343
+ credentials=creds,
 
 
 
 
 
344
  ) as engine:
345
  total = await db_pipeline_service.run(user_id=user_id, engine=engine)
346
  except NotImplementedError as e:
src/pipeline/db_pipeline/db_pipeline_service.py CHANGED
@@ -10,7 +10,7 @@ async vector writes stay on the event loop.
10
 
11
  import asyncio
12
  from contextlib import contextmanager
13
- from typing import Iterator, Optional
14
 
15
  from langchain_core.documents import Document as LangChainDocument
16
  from sqlalchemy import URL, create_engine
@@ -27,70 +27,106 @@ logger = get_logger("db_pipeline")
27
  class DbPipelineService:
28
  """End-to-end DB ingestion: connect -> introspect -> profile -> embed -> store."""
29
 
30
- def connect(
31
- self,
32
- db_type: DbType,
33
- host: str,
34
- port: int,
35
- database: str,
36
- username: str,
37
- password: str,
38
- ssl_mode: Optional[str] = None,
39
- ) -> Engine:
40
  """Build a SQLAlchemy engine for the user's database.
41
 
42
- Supabase aliases to the Postgres driver (same URL shape). Other
43
- engines raise NotImplementedError until their connector is added.
 
44
 
45
- `ssl_mode` maps to libpq's `sslmode` query param for postgres/supabase
46
- (required for managed DBs like Neon/Supabase: "require", "verify-ca",
47
- "verify-full"). Ignored for other db_types until those connectors land.
48
  """
49
- logger.info(
50
- "connecting to user db", db_type=db_type, host=host, port=port, database=database
51
- )
52
  if db_type in ("postgres", "supabase"):
53
- query = {"sslmode": ssl_mode} if ssl_mode else {}
 
 
54
  url = URL.create(
55
  drivername="postgresql+psycopg2",
56
- username=username,
57
- password=password,
58
- host=host,
59
- port=port,
60
- database=database,
61
  query=query,
62
  )
63
  return create_engine(url)
64
- elif db_type == "mysql":
65
- raise NotImplementedError("MySQL support coming soon")
66
- elif db_type == "sqlserver":
67
- raise NotImplementedError("SQL Server support coming soon")
68
- elif db_type == "bigquery":
69
- raise NotImplementedError("BigQuery support coming soon")
70
- elif db_type == "snowflake":
71
- raise NotImplementedError("Snowflake support coming soon")
72
- else:
73
- raise ValueError(f"Unsupported db_type: {db_type}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  @contextmanager
76
  def engine_scope(
77
- self,
78
- db_type: DbType,
79
- host: str,
80
- port: int,
81
- database: str,
82
- username: str,
83
- password: str,
84
- ssl_mode: Optional[str] = None,
85
  ) -> Iterator[Engine]:
86
  """Yield a connected Engine and dispose its pool on exit.
87
 
88
  API callers should prefer this over raw `connect(...)` so user DB
89
  connection pools do not leak between pipeline runs.
90
  """
91
- engine = self.connect(
92
- db_type, host, port, database, username, password, ssl_mode
93
- )
94
  try:
95
  yield engine
96
  finally:
 
10
 
11
  import asyncio
12
  from contextlib import contextmanager
13
+ from typing import Any, Iterator, Optional
14
 
15
  from langchain_core.documents import Document as LangChainDocument
16
  from sqlalchemy import URL, create_engine
 
27
  class DbPipelineService:
28
  """End-to-end DB ingestion: connect -> introspect -> profile -> embed -> store."""
29
 
30
def connect(self, db_type: DbType, credentials: dict[str, Any]) -> Engine:
    """Build a SQLAlchemy engine for the user's database.

    `credentials` is the plaintext dict matching the per-type schema in
    `src/models/credentials.py`. BigQuery/Snowflake auth models differ
    from host/port/user/pass, so every shape flows through one dict.

    The snowflake-sqlalchemy import is done lazily so an environment
    missing that driver doesn't break module import. (`json` is stdlib;
    it is imported locally only to keep the branch self-contained.)

    Raises:
        ValueError: if `db_type` is not a supported engine.
        KeyError: if a required credential field is missing for the type.
    """
    logger.info("connecting to user db", db_type=db_type)

    if db_type in ("postgres", "supabase"):
        # libpq sslmode ("require", "verify-ca", "verify-full") is the only
        # extra query param forwarded for managed Postgres hosts.
        query = (
            {"sslmode": credentials["ssl_mode"]} if credentials.get("ssl_mode") else {}
        )
        url = URL.create(
            drivername="postgresql+psycopg2",
            username=credentials["username"],
            password=credentials["password"],
            host=credentials["host"],
            port=credentials["port"],
            database=credentials["database"],
            query=query,
        )
        return create_engine(url)

    if db_type == "mysql":
        url = URL.create(
            drivername="mysql+pymysql",
            username=credentials["username"],
            password=credentials["password"],
            host=credentials["host"],
            port=credentials["port"],
            database=credentials["database"],
        )
        # pymysql: empty-dict ssl arg flips SSL on with defaults.
        connect_args = {"ssl": {}} if credentials.get("ssl", True) else {}
        return create_engine(url, connect_args=connect_args)

    if db_type == "sqlserver":
        # `driver` applies to pyodbc only; we ship pymssql. Accept-and-ignore
        # keeps the credential schema stable.
        if credentials.get("driver"):
            logger.info(
                "sqlserver driver hint ignored (using pymssql)",
                driver=credentials["driver"],
            )
        url = URL.create(
            drivername="mssql+pymssql",
            username=credentials["username"],
            password=credentials["password"],
            host=credentials["host"],
            port=credentials["port"],
            database=credentials["database"],
        )
        return create_engine(url)

    if db_type == "bigquery":
        import json

        sa_info = json.loads(credentials["service_account_json"])
        # sqlalchemy-bigquery URL shape: bigquery://<project>/<dataset>
        url = f"bigquery://{credentials['project_id']}/{credentials['dataset_id']}"
        return create_engine(
            url,
            credentials_info=sa_info,
            location=credentials.get("location", "US"),
        )

    if db_type == "snowflake":
        from snowflake.sqlalchemy import URL as SnowflakeURL

        url_kwargs: dict[str, Any] = {
            "account": credentials["account"],
            "user": credentials["username"],
            "password": credentials["password"],
            "database": credentials["database"],
            "schema": (
                credentials.get("db_schema")
                or credentials.get("schema")
                or "PUBLIC"
            ),
            "warehouse": credentials["warehouse"],
        }
        # Omit `role` entirely when absent: an empty-string role would be
        # serialized into the connection URL (`?role=`) by snowflake-sqlalchemy
        # instead of letting the account's default role apply.
        if credentials.get("role"):
            url_kwargs["role"] = credentials["role"]
        return create_engine(SnowflakeURL(**url_kwargs))

    raise ValueError(f"Unsupported db_type: {db_type}")
119
 
120
  @contextmanager
121
  def engine_scope(
122
+ self, db_type: DbType, credentials: dict[str, Any]
 
 
 
 
 
 
 
123
  ) -> Iterator[Engine]:
124
  """Yield a connected Engine and dispose its pool on exit.
125
 
126
  API callers should prefer this over raw `connect(...)` so user DB
127
  connection pools do not leak between pipeline runs.
128
  """
129
+ engine = self.connect(db_type, credentials)
 
 
130
  try:
131
  yield engine
132
  finally:
src/pipeline/db_pipeline/extractor.py CHANGED
@@ -18,6 +18,28 @@ logger = get_logger("db_extractor")
18
 
19
  TOP_VALUES_THRESHOLD = 0.05 # show top values if distinct_ratio <= 5%
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def _qi(engine: Engine, name: str) -> str:
23
  """Dialect-correct identifier quoting (schema.table also handled if dotted)."""
@@ -95,11 +117,10 @@ def profile_column(
95
  select_cols.append(f"MIN({qc}) AS min_val")
96
  select_cols.append(f"MAX({qc}) AS max_val")
97
  select_cols.append(f"AVG({qc}) AS mean_val")
98
- # PERCENTILE_CONT is supported by Postgres and SQL Server; MySQL would need
99
- # a dialect-specific fallback when that connector is added.
100
- select_cols.append(
101
- f"PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY {qc}) AS median_val"
102
- )
103
  stats = pd.read_sql(f"SELECT {', '.join(select_cols)} FROM {qt}", engine)
104
 
105
  null_count = int(stats.iloc[0]["nulls"])
@@ -116,17 +137,21 @@ def profile_column(
116
  profile["min"] = stats.iloc[0]["min_val"]
117
  profile["max"] = stats.iloc[0]["max_val"]
118
  profile["mean"] = stats.iloc[0]["mean_val"]
119
- profile["median"] = stats.iloc[0]["median_val"]
 
120
 
121
  if 0 < distinct_ratio <= TOP_VALUES_THRESHOLD:
122
- top = pd.read_sql(
123
- f"SELECT {qc}, COUNT(*) AS cnt FROM {qt} "
124
- f"GROUP BY {qc} ORDER BY cnt DESC LIMIT 10",
125
  engine,
 
 
 
 
126
  )
 
127
  profile["top_values"] = list(zip(top[col_name].tolist(), top["cnt"].tolist()))
128
 
129
- sample = pd.read_sql(f"SELECT {qc} FROM {qt} LIMIT 5", engine)
130
  profile["sample_values"] = sample[col_name].tolist()
131
 
132
  return profile
@@ -178,7 +203,9 @@ def build_text(table_name: str, row_count: int, col: dict, profile: dict) -> str
178
  text += f"Distinct count: {profile['distinct_count']} ({profile['distinct_ratio']:.1%})\n"
179
  if "min" in profile:
180
  text += f"Min: {profile['min']}, Max: {profile['max']}\n"
181
- text += f"Mean: {profile['mean']}, Median: {profile['median']}\n"
 
 
182
  if "top_values" in profile:
183
  top_str = ", ".join(f"{v} ({c})" for v, c in profile["top_values"])
184
  text += f"Top values: {top_str}\n"
 
18
 
19
  TOP_VALUES_THRESHOLD = 0.05 # show top values if distinct_ratio <= 5%
20
 
21
+ # Dialects where PERCENTILE_CONT(...) WITHIN GROUP is supported as an aggregate.
22
+ # MySQL has no percentile aggregate; BigQuery has PERCENTILE_CONT only as an
23
+ # analytic (window) function — both drop median and keep min/max/mean.
24
+ _MEDIAN_DIALECTS = frozenset({"postgresql", "mssql", "snowflake"})
25
+
26
+
27
+ def _supports_median(engine: Engine) -> bool:
28
+ return engine.dialect.name in _MEDIAN_DIALECTS
29
+
30
+
31
+ def _head_query(
32
+ engine: Engine,
33
+ select_clause: str,
34
+ from_clause: str,
35
+ n: int,
36
+ order_by: str = "",
37
+ ) -> str:
38
+ """LIMIT/TOP-equivalent head query for the engine's dialect."""
39
+ if engine.dialect.name == "mssql":
40
+ return f"SELECT TOP {n} {select_clause} FROM {from_clause} {order_by}".strip()
41
+ return f"SELECT {select_clause} FROM {from_clause} {order_by} LIMIT {n}".strip()
42
+
43
 
44
  def _qi(engine: Engine, name: str) -> str:
45
  """Dialect-correct identifier quoting (schema.table also handled if dotted)."""
 
117
  select_cols.append(f"MIN({qc}) AS min_val")
118
  select_cols.append(f"MAX({qc}) AS max_val")
119
  select_cols.append(f"AVG({qc}) AS mean_val")
120
+ if _supports_median(engine):
121
+ select_cols.append(
122
+ f"PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY {qc}) AS median_val"
123
+ )
 
124
  stats = pd.read_sql(f"SELECT {', '.join(select_cols)} FROM {qt}", engine)
125
 
126
  null_count = int(stats.iloc[0]["nulls"])
 
137
  profile["min"] = stats.iloc[0]["min_val"]
138
  profile["max"] = stats.iloc[0]["max_val"]
139
  profile["mean"] = stats.iloc[0]["mean_val"]
140
+ if _supports_median(engine):
141
+ profile["median"] = stats.iloc[0]["median_val"]
142
 
143
  if 0 < distinct_ratio <= TOP_VALUES_THRESHOLD:
144
+ top_sql = _head_query(
 
 
145
  engine,
146
+ select_clause=f"{qc}, COUNT(*) AS cnt",
147
+ from_clause=f"{qt} GROUP BY {qc}",
148
+ n=10,
149
+ order_by="ORDER BY cnt DESC",
150
  )
151
+ top = pd.read_sql(top_sql, engine)
152
  profile["top_values"] = list(zip(top[col_name].tolist(), top["cnt"].tolist()))
153
 
154
+ sample = pd.read_sql(_head_query(engine, qc, qt, 5), engine)
155
  profile["sample_values"] = sample[col_name].tolist()
156
 
157
  return profile
 
203
  text += f"Distinct count: {profile['distinct_count']} ({profile['distinct_ratio']:.1%})\n"
204
  if "min" in profile:
205
  text += f"Min: {profile['min']}, Max: {profile['max']}\n"
206
+ text += f"Mean: {profile['mean']}\n"
207
+ if profile.get("median") is not None:
208
+ text += f"Median: {profile['median']}\n"
209
  if "top_values" in profile:
210
  top_str = ", ".join(f"{v} ({c})" for v, c in profile["top_values"])
211
  text += f"Top values: {top_str}\n"