import sqlite3
import torch
import re
import time
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
from src.sql_validator import SQLValidator
from src.schema_encoder import SchemaEncoder

PROJECT_ROOT = Path(__file__).resolve().parents[1]

# ================================
# DATABASE PATH AUTO DETECTION
# ================================
if (PROJECT_ROOT / "data/database").exists():
    DB_ROOT = PROJECT_ROOT / "data/database"
else:
    DB_ROOT = PROJECT_ROOT / "final_databases"


def normalize_question(q: str):
    """Lowercase, trim and collapse whitespace in a user question.

    Also rewrites "distinct N" to "N distinct" so the number can later be
    picked up by the LIMIT heuristic in ``semantic_fix``.
    """
    q = q.lower().strip()
    q = re.sub(r"distinct\s+(\d+)", r"\1 distinct", q)
    q = re.sub(r"\s+", " ", q)
    return q


def semantic_fix(question, sql):
    """Append ``LIMIT N`` when the question asks for N rows but the SQL has none.

    The fix is skipped if the SQL already contains a LIMIT clause or a
    ``count(`` aggregate, where adding a LIMIT would change the intended result.

    Args:
        question: the (normalized) natural-language question.
        sql: the generated SQL string.

    Returns:
        The possibly-augmented SQL string.
    """
    q = question.lower().strip()
    s = sql.lower()
    num_match = re.search(r'\b(?:show|list|top|limit|get|first|last)\s+(\d+)\b', q)
    if num_match and "limit" not in s and "count(" not in s:
        limit_val = num_match.group(1)
        sql = sql.rstrip(";")
        sql = f"{sql.strip()} LIMIT {limit_val}"
    return sql


class Text2SQLEngine:
    """Natural-language-to-SQL engine backed by a (LoRA-adapted) seq2seq model.

    Generates SQLite queries from questions, blocks DML/DDL statements, and
    executes the result against the per-database ``.sqlite`` files in DB_ROOT.
    """

    def __init__(self,
                 adapter_path=None,
                 base_model_name="Salesforce/codet5-base",
                 use_lora=True):
        """Load the base model and, when requested, a LoRA adapter.

        Args:
            adapter_path: optional explicit path to a LoRA adapter directory.
                When ``None`` the adapter is auto-detected under PROJECT_ROOT.
            base_model_name: HF hub id of the base seq2seq model.
            use_lora: if False, run the plain base model with no adapter.
        """
        self.device = "mps" if torch.backends.mps.is_available() else (
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.validator = SQLValidator(DB_ROOT)
        self.schema_encoder = SchemaEncoder(DB_ROOT)
        # Regex used to reject destructive statements; this engine is read-only.
        self.dml_keywords = r'\b(delete|update|insert|drop|alter|truncate)\b'

        print("Loading base model...")
        base = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)

        if not use_lora:
            self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
            self.model = base.to(self.device)
            self.model.eval()
            return

        # FIX: a caller-supplied adapter_path was previously ignored — it was
        # always overwritten by the auto-detection below. Only auto-detect
        # when no path was given.
        if adapter_path is None:
            if (PROJECT_ROOT / "checkpoints/best_rlhf_model").exists():
                adapter_path = PROJECT_ROOT / "checkpoints/best_rlhf_model"
            else:
                adapter_path = PROJECT_ROOT / "best_rlhf_model"
        adapter_path = Path(adapter_path).resolve()

        print("Loading tokenizer and LoRA adapter...")
        try:
            # Prefer a tokenizer saved alongside the adapter (it may carry
            # added tokens); fall back to the base model's tokenizer.
            self.tokenizer = AutoTokenizer.from_pretrained(
                str(adapter_path), local_files_only=True
            )
        except Exception:
            self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)

        self.model = PeftModel.from_pretrained(base, str(adapter_path)).to(self.device)
        self.model.eval()
        print("āœ… RLHF model ready\n")

    def build_prompt(self, question, schema):
        """Build the text prompt fed to the seq2seq model."""
        return f"""You are an expert SQL generator.

Database schema:
{schema}

Generate a valid SQLite query for the question.

Question: {question}

SQL:
"""

    def get_schema(self, db_id):
        """Return the structured schema string for ``db_id``."""
        return self.schema_encoder.structured_schema(db_id)

    def extract_sql(self, text: str):
        """Pull the SQL statement out of raw model output.

        Strips any leading "SQL:" marker, keeps everything from the first
        SELECT onward, and truncates at the first semicolon.
        """
        text = text.strip()
        if "SQL:" in text:
            text = text.split("SQL:")[-1]
        match = re.search(r"select[\s\S]*", text, re.IGNORECASE)
        if match:
            text = match.group(0)
        return text.split(";")[0].strip()

    def clean_sql(self, sql: str):
        """Normalize quoting and whitespace in a SQL string."""
        sql = sql.replace('"', "'")
        sql = re.sub(r"\s+", " ", sql)
        return sql.strip()

    def generate_sql(self, prompt):
        """Run beam-search generation and return the cleaned SQL string."""
        inputs = self.tokenizer(
            prompt, return_tensors="pt", truncation=True, max_length=512
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=128,
                num_beams=5,
                early_stopping=True
            )
        decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return self.clean_sql(self.extract_sql(decoded))

    def execute_sql(self, question, sql, db_id):
        """Execute ``sql`` against the ``db_id`` database.

        Returns:
            (final_sql, columns, rows, error) — ``error`` is ``None`` on
            success, otherwise a message string.
        """
        if re.search(self.dml_keywords, sql, re.IGNORECASE):
            return sql, [], [], "āŒ Security Alert"

        # FIXED DATABASE PATH
        db_path = DB_ROOT / f"{db_id}.sqlite"
        sql = self.clean_sql(sql)
        sql = semantic_fix(question, sql)
        try:
            conn = sqlite3.connect(db_path)
            # FIX: close the connection even when execute() raises — the
            # original only closed it on the success path (resource leak).
            try:
                cursor = conn.cursor()
                cursor.execute(sql)
                rows = cursor.fetchall()
                columns = [d[0] for d in cursor.description] if cursor.description else []
                return sql, columns, rows, None
            finally:
                conn.close()
        except Exception as e:
            return sql, [], [], str(e)

    def ask(self, question, db_id):
        """End-to-end pipeline: normalize, generate, execute, package result.

        Returns a dict with keys: question, sql, columns, rows, error.
        """
        question = normalize_question(question)
        # Reject prompts that themselves contain DML/DDL keywords.
        if re.search(self.dml_keywords, question, re.IGNORECASE):
            return {
                "question": question,
                "sql": "-- BLOCKED",
                "columns": [],
                "rows": [],
                "error": "Malicious prompt"
            }
        schema = self.get_schema(db_id)
        prompt = self.build_prompt(question, schema)
        raw_sql = self.generate_sql(prompt)
        final_sql, cols, rows, error = self.execute_sql(question, raw_sql, db_id)
        return {
            "question": question,
            "sql": final_sql,
            "columns": cols,
            "rows": rows,
            "error": error
        }


# Lazily-constructed module-level singleton engine.
_engine = None


def get_engine():
    """Return the shared Text2SQLEngine, constructing it on first use."""
    global _engine
    if _engine is None:
        _engine = Text2SQLEngine()
    return _engine


# import sqlite3
# import torch
# import re
# import os
# from pathlib import Path
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# from peft import PeftModel
# from src.sql_validator import SQLValidator
# from src.schema_encoder import SchemaEncoder # Removed build_schema_graph import
# PROJECT_ROOT = Path(__file__).resolve().parents[1]
# # ================================
# # DATABASE PATH AUTO DETECTION
# # ================================
# if (PROJECT_ROOT / "data/database").exists():
#     DB_ROOT = PROJECT_ROOT / "data/database"
# else:
#     DB_ROOT = PROJECT_ROOT / "final_databases"
# # ==========================================
# # SCHEMA PARSING
# # ==========================================
# def build_schema_graph(schema_text):
#     """
#     Parses a structured schema text string into a dictionary graph.
# Matches formats like: table_name(col1, col2, col3) # """ # tables = {} # for match in re.findall(r'(\w+)\s*\((.*?)\)', schema_text): # table = match[0] # cols = [c.strip().split()[0] for c in match[1].split(",")] # tables[table] = cols # return tables # # ========================================== # # INPUT VALIDATION & RELEVANCE # # ========================================== # def is_valid_question(q: str): # q = q.strip().lower() # if len(q) < 3: # return False # words = re.findall(r"[a-zA-Z]+", q) # if len(words) < 1: # return False # return True # def is_relevant_to_db(question: str, schema_graph: dict): # q_words = set(re.findall(r'\b[a-z]{3,}\b', question.lower())) # stop_words = {"show", "list", "all", "and", "the", "get", "find", "how", "many", "what", "where", "which", "who", "give", "display", "count", "from", "for", "with", "that", "have", "has", "are", "there"} # q_words = q_words - stop_words # if not q_words: # return True # schema_words = set() # for table, cols in schema_graph.items(): # schema_words.update(re.findall(r'\b[a-z]{3,}\b', table.lower())) # for col in cols: # schema_words.update(re.findall(r'\b[a-z]{3,}\b', col.lower())) # synonyms = { # "customer": ["client", "buyer", "shopper", "person", "people", "user"], # "employee": ["staff", "worker", "boss", "manager", "person", "people"], # "track": ["song", "music", "audio", "tune"], # "album": ["record", "cd", "music"], # "artist": ["singer", "band", "musician", "creator"], # "invoice": ["bill", "receipt", "purchase", "sale", "order", "buy", "bought", "cost"], # "city": ["town", "location", "place"], # "country": ["nation", "location", "place"], # "flight": ["plane", "airline", "trip", "fly", "airport"], # "student": ["pupil", "learner", "kid", "child"], # "club": ["group", "organization", "team"], # "course": ["class", "subject"], # "cinema": ["movie", "film", "theater", "screen"] # } # extended_schema_words = set(schema_words) # for db_word in schema_words: # if db_word in synonyms: # 
extended_schema_words.update(synonyms[db_word]) # extended_schema_words.update({"id", "name", "total", "sum", "average", "avg", "min", "max", "number", "amount", "record", "data", "info", "information", "detail", "first", "last", "most", "least", "cheapest", "expensive", "best"}) # for qw in q_words: # qw_singular = qw[:-1] if qw.endswith('s') else qw # if qw in extended_schema_words or qw_singular in extended_schema_words: # return True # return False # def normalize_question(q: str): # return re.sub(r"\s+", " ", q.lower().strip()) # def semantic_fix(question, sql): # q = question.lower() # num_match = re.search(r'\b(?:show|list|top|get)\s+(\d+)\b', q) # if num_match and "limit" not in sql.lower(): # sql = f"{sql} LIMIT {num_match.group(1)}" # return sql # # ========================================== # # SCHEMA CONSTRAINTS (SIMULATED LOGIT BLOCKING) # # ========================================== # def apply_schema_constraints(sql, schema_graph): # sql = sql.lower() # used_tables = [t[1] for t in re.findall(r'(from|join)\s+(\w+)', sql)] # for t in used_tables: # if t not in schema_graph: # return None # valid_columns = set() # for cols in schema_graph.values(): # valid_columns.update(cols) # col_blocks = re.findall(r'select\s+(.*?)\s+from', sql) # for block in col_blocks: # for c in block.split(","): # c = c.strip().split()[-1] # if "." 
in c: # c = c.split(".")[-1] # if c != "*" and "(" not in c and c != "": # if c not in valid_columns: # return None # return sql # # ========================================== # # ENGINE # # ========================================== # class Text2SQLEngine: # def __init__(self, # adapter_path="checkpoints/best_rlhf_model_2", # base_model_name="Salesforce/codet5-base", # use_lora=True, # use_constrained_decoding=False): # self.device = "mps" if torch.backends.mps.is_available() else ( # "cuda" if torch.cuda.is_available() else "cpu" # ) # self.validator = SQLValidator(DB_ROOT) # self.schema_encoder = SchemaEncoder(DB_ROOT) # self.use_constrained_decoding = use_constrained_decoding # self.dml_keywords = r'\b(delete|update|insert|drop|alter|truncate|create)\b' # print(f"\nšŸ“¦ Loading model on {self.device}...") # base = AutoModelForSeq2SeqLM.from_pretrained(base_model_name) # # Override the redundant special tokens to prevent the tokenizer crash # self.tokenizer = AutoTokenizer.from_pretrained( # base_model_name, # use_fast=False, # additional_special_tokens=[] # ) # # šŸ”„ FIXED LOADA ADAPTER PATH LOGIC # if use_lora: # if adapter_path and (PROJECT_ROOT / adapter_path).exists(): # adapter_path = PROJECT_ROOT / adapter_path # elif (PROJECT_ROOT / "checkpoints/best_rlhf_model_2").exists(): # adapter_path = PROJECT_ROOT / "checkpoints/best_rlhf_model_2" # else: # adapter_path = PROJECT_ROOT / "best_rlhf_model_2" # adapter_path = adapter_path.resolve() # if adapter_path.exists(): # try: # self.model = PeftModel.from_pretrained( # base, # str(adapter_path), # local_files_only=True # ).to(self.device) # print(f"āœ… LoRA loaded from {adapter_path}") # except Exception as e: # print(f"āš ļø LoRA load failed: {e}") # self.model = base.to(self.device) # else: # print(f"āš ļø Adapter not found at {adapter_path}, using base model") # self.model = base.to(self.device) # else: # self.model = base.to(self.device) # self.model.eval() # def build_prompt(self, question, schema): # 
return f""" # You are an expert SQL generator. # IMPORTANT: # - Use correct tables and columns # - Use JOINs when needed # Schema: # {schema} # Question: # {question} # SQL: # """ # def get_schema(self, db_id): # return self.schema_encoder.structured_schema(db_id) # def extract_sql(self, text): # match = re.search(r"(select|with)[\s\S]*", text, re.IGNORECASE) # return match.group(0).split(";")[0].strip() if match else "" # def clean_sql(self, sql): # return re.sub(r"\s+", " ", sql.replace('"', "'")).strip() # def generate_sql(self, prompt): # inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device) # with torch.no_grad(): # outputs = self.model.generate( # **inputs, # max_new_tokens=128, # num_beams=8, # length_penalty=0.8, # early_stopping=True # ) # decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True) # return self.clean_sql(self.extract_sql(decoded)) # def execute_sql(self, question, sql, db_id): # if re.search(self.dml_keywords, sql, re.IGNORECASE): # return "", [], [], "āŒ Blocked malicious SQL (Contains INSERT/UPDATE/DELETE/DROP)" # # šŸ”„ FIXED DATABASE PATH # db_path = DB_ROOT / f"{db_id}.sqlite" # sql = semantic_fix(question, sql) # try: # conn = sqlite3.connect(db_path) # cursor = conn.cursor() # cursor.execute(sql) # rows = cursor.fetchall() # columns = [d[0] for d in cursor.description] if cursor.description else [] # conn.close() # return sql, columns, rows, None # except Exception as e: # return sql, [], [], str(e) # def ask(self, question, db_id): # question = normalize_question(question) # question_context = f"Database question: {question}" # if re.search(self.dml_keywords, question_context, re.IGNORECASE): # return {"sql": "", "error": "āŒ Blocked dangerous query from input text."} # if not is_valid_question(question_context): # return {"sql": "", "error": "āŒ Invalid input. 
Please type words."} # schema = self.get_schema(db_id) # schema_graph = build_schema_graph(schema) # if not is_relevant_to_db(question, schema_graph): # return {"sql": "", "error": "āŒ Question is completely out of domain for the selected database."} # sql = self.generate_sql(self.build_prompt(question_context, schema)) # if self.use_constrained_decoding: # filtered_sql = apply_schema_constraints(sql, schema_graph) # if filtered_sql is None: # constraint_prompt = f""" # Use ONLY valid schema. # Schema: # {schema} # Question: # {question_context} # SQL: # """ # sql_retry = self.generate_sql(constraint_prompt) # filtered_sql = apply_schema_constraints(sql_retry, schema_graph) # if filtered_sql: # sql = filtered_sql # else: # sql = sql_retry # final_sql, cols, rows, error = self.execute_sql(question_context, sql, db_id) # return { # "question": question_context, # "sql": final_sql, # "columns": cols, # "rows": rows, # "error": error # } # def get_engine( # adapter_path="checkpoints/best_rlhf_model_2", # base_model_name="Salesforce/codet5-base", # use_lora=True, # use_constrained=True # ): # return Text2SQLEngine( # adapter_path=adapter_path, # base_model_name=base_model_name, # use_lora=use_lora, # use_constrained_decoding=use_constrained # )