| | import gradio as gr |
| | import torch |
| | from transformers import AutoTokenizer, AutoModelForCausalLM |
| | import sqlparse |
| | import psutil |
| | import os |
| |
|
| | |
def get_available_memory():
    """Return the amount of system memory currently available, in bytes."""
    memory_stats = psutil.virtual_memory()
    return memory_stats.available
| |
|
# Identifier of the pretrained text-to-SQL checkpoint; it is passed to the
# ``from_pretrained`` loaders for both the tokenizer and the model.
model_name = "defog/llama-3-sqlcoder-8b"

# The tokenizer is loaded once at import time and shared by every request.
tokenizer = AutoTokenizer.from_pretrained(model_name)
| |
|
| | |
def load_model():
    """Load the causal language model onto the CPU in half precision.

    The weights are always loaded as ``torch.float16``. A float32 copy of the
    same checkpoint needs roughly twice the memory, so falling back to
    float32 when memory is scarce would only make an out-of-memory failure
    more likely. When less memory is available than the float16 weights are
    expected to need, a warning is printed and loading is still attempted.

    Returns:
        The loaded model, or ``None`` if loading failed for any reason (the
        error is printed so the app can still start and report the problem).
    """
    # Rough footprint of the float16 weights: ~8B parameters * 2 bytes each.
    required_bytes = 16e9

    try:
        available_memory = get_available_memory()
        print(f"Available memory: {available_memory / 1e9:.1f} GB")

        if available_memory <= required_bytes:
            print(
                "Warning: less than 16 GB of memory is available; "
                "loading the model may fail or be very slow."
            )

        print("Loading model in float16...")
        # NOTE(review): trust_remote_code=True allows code shipped with the
        # checkpoint to run locally; confirm it is actually required.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float16,
            device_map="cpu",
            use_cache=True,
            low_cpu_mem_usage=True,
        )
        return model
    except Exception as e:
        # Deliberate best-effort: callers check for None instead of crashing.
        print(f"Error loading model: {e}")
        return None
| |
|
| | |
# The model is loaded eagerly at import time so the first request does not pay
# the loading cost. ``model`` is None when loading failed.
print("Loading model... This may take a few minutes on CPU.")
model = load_model()
| |
|
# Prompt sent to the model. ``{question}`` is filled in twice via str.format:
# once in the user turn and once in the assistant preamble. The template
# embeds the database schema as DDL and ends with an opening ```sql fence, so
# the model's completion is expected to start directly with the query.
prompt_template = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{question}`

DDL statements:

CREATE TABLE expenses (
id INTEGER PRIMARY KEY, -- Unique ID for each expense
date DATE NOT NULL, -- Date when the expense occurred
amount DECIMAL(10,2) NOT NULL, -- Amount spent
category VARCHAR(50) NOT NULL, -- Category of expense (food, transport, utilities, etc.)
description TEXT, -- Optional description of the expense
payment_method VARCHAR(20), -- How the payment was made (cash, credit_card, debit_card, bank_transfer)
user_id INTEGER -- ID of the user who made the expense
);

CREATE TABLE categories (
id INTEGER PRIMARY KEY, -- Unique ID for each category
name VARCHAR(50) UNIQUE NOT NULL, -- Category name (food, transport, utilities, entertainment, etc.)
description TEXT -- Optional description of the category
);

CREATE TABLE users (
id INTEGER PRIMARY KEY, -- Unique ID for each user
username VARCHAR(50) UNIQUE NOT NULL, -- Username
email VARCHAR(100) UNIQUE NOT NULL, -- Email address
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP -- When the user account was created
);

CREATE TABLE budgets (
id INTEGER PRIMARY KEY, -- Unique ID for each budget
user_id INTEGER, -- ID of the user who set the budget
category VARCHAR(50), -- Category for which budget is set
amount DECIMAL(10,2) NOT NULL, -- Budget amount
period VARCHAR(20) DEFAULT 'monthly', -- Budget period (daily, weekly, monthly, yearly)
start_date DATE, -- Budget start date
end_date DATE -- Budget end date
);

-- expenses.user_id can be joined with users.id
-- expenses.category can be joined with categories.name
-- budgets.user_id can be joined with users.id
-- budgets.category can be joined with categories.name<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{question}`:
```sql
"""
| |
|
def _extract_sql(completion):
    """Return the SQL statement contained in a raw model completion."""
    text = completion.strip()
    # The prompt already opens a ```sql fence, so the completion is normally
    # the bare query followed by a closing fence. Tolerate a model that
    # repeats the opening fence anyway.
    if text.startswith("```sql"):
        text = text[len("```sql"):]
    sql = text.split("```", 1)[0].strip()
    # Drop a single trailing semicolon so the displayed query has none.
    if sql.endswith(";"):
        sql = sql[:-1]
    return sql


def generate_query(question):
    """Generate a formatted SQL query answering a natural-language question.

    The question is substituted into ``prompt_template``, the model greedily
    decodes up to 400 new tokens, and the SQL found in the completion is
    pretty-printed with ``sqlparse``.

    Args:
        question: Natural-language question about the expense database.

    Returns:
        The formatted SQL string, or a message beginning with ``"Error"``
        when the model is not loaded or generation/parsing fails.
    """
    if model is None:
        return "Error: Model not loaded properly"

    try:
        updated_prompt = prompt_template.format(question=question)
        inputs = tokenizer(updated_prompt, return_tensors="pt")
        prompt_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            # Greedy decoding: sampling-only knobs (temperature, top_p) are
            # intentionally omitted because they have no effect when
            # do_sample=False and only trigger configuration warnings.
            generated_ids = model.generate(
                **inputs,
                num_return_sequences=1,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id,
                max_new_tokens=400,
                do_sample=False,
                num_beams=1,
            )

        # The returned sequence starts with the prompt tokens. Decode only
        # what the model added, so that fences or phrases inside the user's
        # question can never be mistaken for the generated query.
        completion_ids = generated_ids[:, prompt_length:]
        completion = tokenizer.batch_decode(
            completion_ids, skip_special_tokens=True
        )[0]

        sql_part = _extract_sql(completion)
        return sqlparse.format(sql_part, reindent=True, keyword_case="upper")

    except Exception as e:
        # UI boundary: report the failure as text instead of crashing the app.
        return f"Error generating query: {str(e)}"
| |
|
def gradio_interface(question):
    """Validate the textbox input and return the generated SQL.

    Args:
        question: Text entered by the user; may be ``None`` or blank.

    Returns:
        A prompt to enter a question when the input is missing or contains
        only whitespace, otherwise the result of ``generate_query``.
    """
    # Guard against None as well as blank text so .strip() cannot raise.
    if question is None or not question.strip():
        return "Please enter a question."

    return generate_query(question)
| |
|
| | |
# Widgets and sample questions are built up front so the Interface call below
# reads as a short summary of the app.
_question_box = gr.Textbox(
    label="Question",
    placeholder="Enter your question (e.g., 'Show me all expenses for food category')",
    lines=3,
)
_sql_output = gr.Code(label="Generated SQL Query", language="sql")
_example_questions = [
    ["Show me all expenses for food category"],
    ["What's the total amount spent on transport this month?"],
    ["Insert a new expense of 50 dollars for groceries on 2024-01-15"],
    ["Find users who spent more than 1000 dollars total"],
    ["Show me the budget vs actual spending for each category"],
]

# Single-input, single-output UI: a question goes in, formatted SQL comes out.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=_question_box,
    outputs=_sql_output,
    title="SQL Query Generator",
    description="Generate SQL queries from natural language questions about expense tracking database.",
    examples=_example_questions,
    cache_examples=False,
)


if __name__ == "__main__":
    # Start the web server only when the file is run as a script.
    iface.launch()