| from dotenv import load_dotenv |
| import os |
| from sentence_transformers import SentenceTransformer |
| import gradio as gr |
| from sklearn.metrics.pairwise import cosine_similarity |
| import google.generativeai as genai |
| import os |
| from dotenv import load_dotenv |
| import pandas as pd |
# Pull variables from a local .env file (if present) into os.environ before
# the Gemini key is looked up.
load_dotenv()

GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")

# A missing or empty key only produces a warning; genai is configured with
# whatever value was found (possibly None) either way.
if GEMINI_API_KEY is None or GEMINI_API_KEY == "":
    print("Warning: GEMINI_API_KEY not set in environment. Set it in your .env file or system env vars.")

genai.configure(api_key=GEMINI_API_KEY)
|
|
| |
# CSV tables are expected under ./data; when that path is missing, fall back
# to the current working directory so the script can still start.
dataset_folder = "./data"
if not os.path.exists(dataset_folder):
    print(f"Warning: Dataset folder '{dataset_folder}' not found. Using current directory instead.")
    dataset_folder = "."

# Show what will be scanned for *.csv files.
print("Available files:", os.listdir(dataset_folder))
|
|
| import warnings |
|
|
| |
# Silence pandas' mixed-dtype warnings emitted while reading wide/irregular CSVs.
warnings.simplefilter(action="ignore", category=pd.errors.DtypeWarning)
|
|
| |
def _read_csv_as_text(path):
    """Read the CSV at *path* with every column typed as ``str``.

    UTF-8 is tried first; on a UnicodeDecodeError the file is re-read as
    latin-1, which can decode any byte sequence.
    """
    try:
        return pd.read_csv(path, dtype=str, low_memory=False, encoding="utf-8")
    except UnicodeDecodeError:
        return pd.read_csv(path, dtype=str, low_memory=False, encoding="latin1")


# Combine every readable CSV in dataset_folder into one all-text DataFrame.
# dtype=str already applies to every column, so no separate sampling read is
# needed to discover column names. Files are visited in sorted order so the
# concatenated row order is deterministic across platforms.
dataframes = []
for file in sorted(os.listdir(dataset_folder)):
    if not file.endswith(".csv"):
        continue
    path = os.path.join(dataset_folder, file)
    try:
        df = _read_csv_as_text(path)
    except Exception as e:
        # Best-effort load: report and skip unreadable files instead of aborting.
        print(f"Error reading {file}: {e}")
        continue
    # Missing cells become empty strings so every value is a str.
    dataframes.append(df.fillna(''))

if dataframes:
    full_data = pd.concat(dataframes, ignore_index=True)
else:
    print("Warning: No valid CSV files found in the dataset folder.")
    full_data = pd.DataFrame()
| |
|
|
def load_dataset_metadata(dataset_folder):
    """Load every CSV in *dataset_folder* and describe each table's columns.

    Files are visited in sorted name order so the returned lists (and hence
    table matching) are deterministic. Each file is decoded as UTF-8 first and
    re-read as latin-1 if that raises UnicodeDecodeError.

    Args:
        dataset_folder: Directory to scan; only names ending in ``.csv`` are read.

    Returns:
        A ``(dataframes, metadata_list)`` pair. ``dataframes`` is a list of
        ``(filename, DataFrame)`` tuples. ``metadata_list`` holds one
        ``"\\nTable: <name>\\nColumns:\\n<col, col, ...>\\n"`` string per file,
        in the same order.

    Raises:
        FileNotFoundError: if *dataset_folder* does not exist (from ``os.listdir``).
        Other pandas parsing errors propagate unchanged.
    """
    dataframes = []
    metadata_list = []

    for file in sorted(os.listdir(dataset_folder)):
        if not file.endswith(".csv"):
            continue

        path = os.path.join(dataset_folder, file)
        try:
            df = pd.read_csv(path)
        except UnicodeDecodeError:
            # Legacy single-byte exports: latin-1 accepts any byte sequence.
            df = pd.read_csv(path, encoding="latin1")
        dataframes.append((file, df))

        # Drop only the trailing extension; str.replace would also remove a
        # ".csv" occurring inside the name.
        table_name = file[: -len(".csv")]
        column_list = ", ".join(str(col) for col in df.columns)
        metadata_list.append(f"\nTable: {table_name}\nColumns:\n{column_list}\n")

    return dataframes, metadata_list
|
|
import functools


@functools.lru_cache(maxsize=4)
def _load_sentence_model(model_name):
    """Instantiate the SentenceTransformer *model_name* once and reuse it afterwards."""
    return SentenceTransformer(model_name)


def create_metadata_embeddings(metadata_list, model_name='all-MiniLM-L6-v2'):
    """Embed each table description in *metadata_list*.

    The encoder is memoized per *model_name*, so repeated calls (one per UI
    request) no longer re-instantiate the model every time.

    Args:
        metadata_list: Table description strings to encode.
        model_name: SentenceTransformer model identifier; defaults to the
            model this module has always used.

    Returns:
        ``(embeddings, model)``: the output of ``model.encode(metadata_list)``
        and the encoder itself, which callers reuse to embed the user query.
    """
    # NOTE(review): the cached model is now shared by concurrent Gradio
    # requests; inference-only use is assumed to be safe — confirm if the
    # interface is run with high concurrency.
    model = _load_sentence_model(model_name)
    embeddings = model.encode(metadata_list)
    return embeddings, model
|
|
def find_best_fit(embeddings, model, user_query, metadata_list):
    """Return the metadata entry whose embedding is most similar to *user_query*.

    Args:
        embeddings: Embeddings of *metadata_list*, in the same order.
        model: Encoder that produced *embeddings*; used to embed the query.
        user_query: Natural-language question.
        metadata_list: Table descriptions to choose from.

    Returns:
        The element of *metadata_list* with the highest cosine similarity.

    Raises:
        ValueError: if *metadata_list* is empty.
    """
    if not metadata_list:
        raise ValueError("find_best_fit: metadata_list is empty; no tables to match against.")

    query_embedding = model.encode([user_query])
    # Assuming encode returns one row per input, the score matrix is (1, N),
    # so the flat argmax is the index of the best-matching table.
    scores = cosine_similarity(query_embedding, embeddings)
    return metadata_list[int(scores.argmax())]
|
|
def create_prompt(user_query, table_metadata):
    """Build the Gemini prompt that asks for a single SQL query.

    The prompt embeds the matched table description and the user's question,
    lists strict output rules (SQL only, or the literal
    "ERROR: Unable to generate query." when impossible), and ends with two
    few-shot examples. It begins and ends with a newline.
    """
    prompt_lines = [
        "",
        "You are an AI assistant that generates precise SQL queries based on user questions.",
        "",
        "**Table Name & Columns:**",
        f"{table_metadata}",
        "",
        "**User Query:**",
        f"{user_query}",
        "",
        "**Output Format (STRICT):**",
        "- Provide ONLY the SQL query.",
        "- Do NOT include explanations, comments, or unnecessary text.",
        "- Ensure the table and column names match exactly.",
        '- If the query is impossible, return: "ERROR: Unable to generate query."',
        "",
        "**Example Queries:**",
        '- User: "Show all startups founded in 2020."',
        "- AI Response: SELECT * FROM startups WHERE founded_year = 2020;",
        "",
        '- User: "List the top 5 startups by total funding."',
        "- AI Response: SELECT name, total_funding FROM startups ORDER BY total_funding DESC LIMIT 5;",
        "",
    ]
    return "\n".join(prompt_lines)
|
|
|
|
def _extract_sql(text):
    """Strip surrounding whitespace and an optional Markdown code fence from *text*.

    Models often answer with ```sql ... ``` despite instructions; the fence must
    be removed before the statement is validated or shown to the user.
    """
    cleaned = text.strip()
    if cleaned.startswith("```"):
        first_line, newline, rest = cleaned.partition("\n")
        # Multi-line fence: drop the opening line (``` or ```sql).
        # Single-line fence: just trim the backticks.
        cleaned = rest if newline else first_line.strip("`")
        cleaned = cleaned.strip()
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3].rstrip()
    return cleaned


def _is_read_query(sql):
    """Return True when *sql* starts with the keyword SELECT or WITH (case-insensitive)."""
    import re

    # \b keeps words like "Without ..." from passing as a WITH query.
    return re.match(r"(?:select|with)\b", sql, re.IGNORECASE) is not None


def generate_sql_query(system_prompt, model_name="gemini-2.5-pro"):
    """Ask Gemini for a SQL query answering *system_prompt*.

    Args:
        system_prompt: Full prompt text (e.g. from ``create_prompt``).
        model_name: Model id passed to ``genai.GenerativeModel``.

    Returns:
        The SQL text, with any Markdown code fence removed, when it starts with
        SELECT or WITH. Otherwise "⚠️ Invalid SQL query generated.". Any
        exception raised while calling the API or reading its reply is logged
        and yields "⚠️ API failed. Check logs." instead of propagating.
    """
    try:
        model = genai.GenerativeModel(model_name)
        api_response = model.generate_content(system_prompt)
        print("🔍 Full API Response:", api_response)

        # NOTE(review): .text is believed to raise when the reply has no text
        # part (e.g. a blocked response); that lands in the handler below.
        result = _extract_sql(api_response.text)
        print(f"✅ AI Response: {result}")
    except Exception as e:
        # UI boundary: report a message instead of crashing the request.
        print(f"❌ Gemini API Error: {e}")
        return "⚠️ API failed. Check logs."

    if _is_read_query(result):
        return result

    print("⚠️ Gemini did not generate a valid SQL query.")
    return "⚠️ Invalid SQL query generated."
|
|
def response(user_query, dataset_folder):
    """Turn a natural-language question into a SQL query for the best-matching table.

    Pipeline:
        1. Load the CSV schemas from *dataset_folder* and embed them.
        2. Pick the table whose description is most similar to the question.
        3. Build the prompt and ask Gemini for the SQL.

    Returns:
        The generated SQL. If the question is blank, no CSV tables are found,
        or generation fails, a "⚠️ ..." message string is returned instead.
    """
    if not user_query or not user_query.strip():
        # Skip an unnecessary model call for an empty question.
        return "⚠️ Please enter a question."

    _dataframes, metadata_list = load_dataset_metadata(dataset_folder)
    if not metadata_list:
        # Nothing to match against; similarity search would otherwise fail.
        return "⚠️ No CSV tables found in the dataset folder."

    embeddings, model = create_metadata_embeddings(metadata_list)
    table_metadata = find_best_fit(embeddings, model, user_query, metadata_list)
    system_prompt = create_prompt(user_query, table_metadata)
    return generate_sql_query(system_prompt)
|
|
# dataset_folder is deliberately NOT reassigned here. The module-level value
# falls back to "." when ./data is missing. Resetting it to "./data" would make
# every request fail with FileNotFoundError in that case.

# Example question for manual experimentation; the UI passes its own query.
user_query = "Show me the top 10 startups with the highest funding."


def sql_query_interface(user_query):
    """Gradio callback: generate SQL for *user_query* against the module's ``dataset_folder``."""
    return response(user_query, dataset_folder)
|
|
| |
def _build_interface():
    """Assemble the Gradio UI: one query textbox in, one SQL textbox out."""
    query_box = gr.Textbox(label="Enter your query")
    sql_box = gr.Textbox(label="Generated SQL Query")
    return gr.Interface(
        fn=sql_query_interface,
        inputs=query_box,
        outputs=sql_box,
        title="AI-Powered SQL Query Generator",
    )


iface = _build_interface()

if __name__ == "__main__":
    # Serve the UI only when run as a script, not when imported.
    iface.launch()