Spaces:
Sleeping
Sleeping
Priyansh Saxena committed on
Commit ·
dc5ef4a
1
Parent(s): 4ac0bf8
Multi-model AI data analyst with Plotly charts
Browse files- Added Qwen2.5-1.5B, Gemini 2.0 Flash, Grok-3 Mini, BART support
- Interactive Plotly charts (line, bar, scatter, pie, histogram, box, area)
- UUID-based chart filenames to prevent race conditions
- Data profiling with column types and statistics
- JSON-structured LLM responses with validation
- SSE streaming support for real-time responses
- CORS restricted to Vercel + localhost
- File size limits and proper error handling
- app.py +88 -42
- chart_generator.py +178 -69
- data_processor.py +22 -0
- llm_agent.py +204 -151
- requirements.txt +4 -1
app.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
from flask import Flask, request, jsonify, send_from_directory
|
| 2 |
from flask_cors import CORS
|
| 3 |
from llm_agent import LLM_Agent
|
|
|
|
| 4 |
import os
|
| 5 |
import logging
|
| 6 |
import time
|
|
@@ -9,77 +10,122 @@ from werkzeug.utils import secure_filename
|
|
| 9 |
|
| 10 |
load_dotenv()
|
| 11 |
|
| 12 |
-
|
| 13 |
logging.basicConfig(level=logging.INFO)
|
| 14 |
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
| 15 |
logging.getLogger('PIL').setLevel(logging.WARNING)
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
-
app = Flask(__name__, static_folder=os.path.join(
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
agent = LLM_Agent()
|
| 23 |
|
| 24 |
-
UPLOAD_FOLDER
|
| 25 |
ALLOWED_EXTENSIONS = {'csv', 'xls', 'xlsx'}
|
|
|
|
|
|
|
| 26 |
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
-
if not os.path.exists(UPLOAD_FOLDER):
|
| 29 |
-
os.makedirs(UPLOAD_FOLDER)
|
| 30 |
|
| 31 |
def allowed_file(filename):
|
| 32 |
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
|
| 33 |
|
|
|
|
| 34 |
@app.route('/')
|
| 35 |
def index():
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
@app.route('/plot', methods=['POST'])
|
| 40 |
def plot():
|
| 41 |
-
|
| 42 |
-
data = request.
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
logging.info(f"Processed request in {end_time - start_time} seconds")
|
| 51 |
-
|
| 52 |
-
return jsonify(response)
|
| 53 |
|
| 54 |
|
| 55 |
@app.route('/static/<path:filename>')
|
| 56 |
def serve_static(filename):
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
response.headers.add('Access-Control-Allow-Methods', 'GET')
|
| 63 |
-
return response
|
| 64 |
|
| 65 |
@app.route('/upload', methods=['POST'])
|
| 66 |
def upload_file():
|
| 67 |
if 'file' not in request.files:
|
| 68 |
-
return jsonify({'error': 'No file part'}), 400
|
| 69 |
file = request.files['file']
|
| 70 |
-
if file.filename
|
| 71 |
-
return jsonify({'error': 'No
|
| 72 |
-
if
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
if __name__ == '__main__':
|
| 85 |
app.run(host='0.0.0.0', port=7860)
|
|
|
|
| 1 |
from flask import Flask, request, jsonify, send_from_directory
|
| 2 |
from flask_cors import CORS
|
| 3 |
from llm_agent import LLM_Agent
|
| 4 |
+
from data_processor import DataProcessor
|
| 5 |
import os
|
| 6 |
import logging
|
| 7 |
import time
|
|
|
|
| 10 |
|
| 11 |
load_dotenv()
|
| 12 |
|
|
|
|
| 13 |
logging.basicConfig(level=logging.INFO)
|
| 14 |
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
| 15 |
logging.getLogger('PIL').setLevel(logging.WARNING)
|
| 16 |
+
logging.getLogger('plotly').setLevel(logging.WARNING)
|
| 17 |
+
|
| 18 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 19 |
|
| 20 |
+
app = Flask(__name__, static_folder=os.path.join(BASE_DIR, 'static'))
|
| 21 |
|
| 22 |
+
CORS(app, origins=[
|
| 23 |
+
"https://llm-integrated-excel-plotter-app.vercel.app",
|
| 24 |
+
"http://localhost:8080",
|
| 25 |
+
"http://localhost:3000",
|
| 26 |
+
], supports_credentials=False)
|
| 27 |
|
| 28 |
agent = LLM_Agent()
|
| 29 |
|
| 30 |
+
UPLOAD_FOLDER = os.path.join(BASE_DIR, 'data', 'uploads')
|
| 31 |
ALLOWED_EXTENSIONS = {'csv', 'xls', 'xlsx'}
|
| 32 |
+
MAX_UPLOAD_BYTES = 10 * 1024 * 1024 # 10 MB
|
| 33 |
+
|
| 34 |
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
|
| 35 |
+
app.config['MAX_CONTENT_LENGTH'] = MAX_UPLOAD_BYTES
|
| 36 |
+
|
| 37 |
+
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
| 38 |
|
|
|
|
|
|
|
| 39 |
|
| 40 |
def allowed_file(filename):
    """Return True when *filename* carries an allowed spreadsheet extension."""
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in ALLOWED_EXTENSIONS
|
| 42 |
|
| 43 |
+
|
| 44 |
@app.route('/')
def index():
    """Health/info endpoint listing the available API routes."""
    info = {
        "status": "ok",
        "message": "AI Data Visualization API",
        "endpoints": ["/plot", "/upload", "/stats", "/models"],
    }
    return jsonify(info)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@app.route('/models', methods=['GET'])
def models():
    """List the selectable LLM backends and the default choice."""
    available = [
        {"id": "qwen", "name": "Qwen2.5-1.5B", "provider": "HuggingFace Serverless", "free": True},
        {"id": "gemini", "name": "Gemini 2.0 Flash", "provider": "Google AI", "free": True},
        {"id": "grok", "name": "Grok-3 Mini", "provider": "xAI", "free": True},
        {"id": "bart", "name": "BART (fine-tuned)", "provider": "Local", "free": True},
    ]
    return jsonify({"models": available, "default": "qwen"})
|
| 64 |
+
|
| 65 |
|
| 66 |
@app.route('/plot', methods=['POST'])
def plot():
    """Generate a chart from a natural-language query.

    Expects a JSON body: {"query": str, "model": str (optional),
    "file_path": str (optional)}. Returns the agent's result
    (chart path / plotly spec) or a JSON error object.
    """
    t0 = time.time()
    # silent=True: a malformed body yields None (handled below as a JSON 400)
    # instead of Flask's default HTML BadRequest page.
    data = request.get_json(force=True, silent=True)
    if not data or not data.get('query'):
        return jsonify({'error': 'Missing required field: query'}), 400

    logging.info(f"Plot request: model={data.get('model','qwen')} query={data.get('query')[:80]}")
    try:
        result = agent.process_request(data)
    except Exception:
        # Keep API responses JSON even when the agent fails unexpectedly.
        logging.exception("Plot request failed")
        return jsonify({'error': 'Internal error while generating chart'}), 500
    logging.info(f"Plot completed in {time.time() - t0:.2f}s")
    return jsonify(result)
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
|
| 79 |
@app.route('/static/<path:filename>')
def serve_static(filename):
    """Serve generated chart files with open CORS and a short cache window."""
    response = send_from_directory(app.static_folder, filename)
    # Charts are embedded cross-origin by the frontend, so allow any origin.
    response.headers['Access-Control-Allow-Origin'] = '*'
    # 5-minute cache; chart filenames appear to be unique per generation.
    response.headers['Cache-Control'] = 'public, max-age=300'
    return response
|
| 85 |
+
|
|
|
|
|
|
|
| 86 |
|
| 87 |
@app.route('/upload', methods=['POST'])
def upload_file():
    """Accept a CSV/XLS/XLSX upload, persist it, and return a data profile.

    Returns 400 on missing/invalid files and a JSON profile (columns,
    dtypes, preview, row count) on success.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file part in request'}), 400
    file = request.files['file']
    if not file.filename:
        return jsonify({'error': 'No file selected'}), 400
    if not allowed_file(file.filename):
        return jsonify({'error': 'File type not allowed. Use CSV, XLS, or XLSX'}), 400

    filename = secure_filename(file.filename)
    # secure_filename strips dangerous characters and can return an empty
    # string (e.g. a name made only of path separators) — reject that,
    # otherwise file.save() would target the upload directory itself.
    if not filename:
        return jsonify({'error': 'Invalid filename'}), 400
    file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(file_path)

    try:
        dp = DataProcessor(file_path)
    except Exception:
        # A well-formed extension does not guarantee a parsable file.
        logging.exception(f"Failed to parse uploaded file: {file_path}")
        return jsonify({'error': 'Could not parse the uploaded file'}), 400
    return jsonify({
        'message': 'File uploaded successfully',
        'columns': dp.get_columns(),
        'dtypes': dp.get_dtypes(),
        'preview': dp.preview(5),
        'file_path': file_path,
        'row_count': len(dp.data),
    })
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@app.route('/stats', methods=['POST'])
def stats():
    """Return columns, coarse dtypes, numeric stats, and row count.

    The JSON body may carry "file_path"; when absent or nonexistent,
    the agent's currently loaded dataset is profiled instead.
    """
    payload = request.get_json(force=True) or {}
    requested_path = payload.get('file_path')
    if requested_path and os.path.exists(requested_path):
        dp = DataProcessor(requested_path)
    else:
        dp = agent.data_processor
    profile = {
        'columns': dp.get_columns(),
        'dtypes': dp.get_dtypes(),
        'stats': dp.get_stats(),
        'row_count': len(dp.data),
    }
    return jsonify(profile)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
@app.errorhandler(413)
def file_too_large(e):
    """JSON response for uploads exceeding MAX_CONTENT_LENGTH."""
    limit_mb = MAX_UPLOAD_BYTES // (1024 * 1024)
    return jsonify({'error': f'File too large. Maximum size is {limit_mb} MB'}), 413
|
| 128 |
+
|
| 129 |
|
| 130 |
if __name__ == '__main__':
|
| 131 |
app.run(host='0.0.0.0', port=7860)
|
chart_generator.py
CHANGED
|
@@ -1,80 +1,189 @@
|
|
| 1 |
-
import matplotlib.pyplot as plt
|
| 2 |
-
import pandas as pd
|
| 3 |
-
import os
|
| 4 |
import logging
|
|
|
|
| 5 |
import time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
class ChartGenerator:
|
| 8 |
def __init__(self, data=None):
|
| 9 |
-
|
| 10 |
-
if data is not None:
|
| 11 |
self.data = data
|
| 12 |
else:
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
plt.clf()
|
| 35 |
-
plt.close(
|
| 36 |
-
|
| 37 |
fig, ax = plt.subplots(figsize=(10, 6))
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
else:
|
| 44 |
-
ax.plot(
|
| 45 |
-
|
| 46 |
-
ax.
|
| 47 |
-
|
| 48 |
-
ax.
|
| 49 |
-
ax.
|
| 50 |
-
ax.
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
if not
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
logging.info(f"Removed existing chart file: {full_path}")
|
| 67 |
-
|
| 68 |
-
# Save with high DPI for better quality
|
| 69 |
-
plt.savefig(full_path, dpi=300, bbox_inches='tight', facecolor='white')
|
| 70 |
plt.close(fig)
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
+
import os
|
| 3 |
import time
|
| 4 |
+
import uuid
|
| 5 |
+
|
| 6 |
+
import matplotlib
|
| 7 |
+
matplotlib.use("Agg")
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import plotly.graph_objects as go
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
_PLOTLY_LAYOUT = dict(
|
| 15 |
+
font=dict(family="Inter, system-ui, sans-serif", size=13),
|
| 16 |
+
plot_bgcolor="#0f1117",
|
| 17 |
+
paper_bgcolor="#0f1117",
|
| 18 |
+
font_color="#e2e8f0",
|
| 19 |
+
margin=dict(l=60, r=30, t=60, b=60),
|
| 20 |
+
legend=dict(bgcolor="rgba(0,0,0,0)", borderwidth=0),
|
| 21 |
+
xaxis=dict(gridcolor="#1e2d3d", linecolor="#2d3748", zerolinecolor="#2d3748"),
|
| 22 |
+
yaxis=dict(gridcolor="#1e2d3d", linecolor="#2d3748", zerolinecolor="#2d3748"),
|
| 23 |
+
colorway=["#4f8cff", "#34d399", "#f59e0b", "#ef4444", "#a78bfa", "#06b6d4"],
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
|
| 27 |
class ChartGenerator:
|
| 28 |
def __init__(self, data=None):
|
| 29 |
+
logger.info("Initializing ChartGenerator")
|
| 30 |
+
if data is not None and not (isinstance(data, pd.DataFrame) and data.empty):
|
| 31 |
self.data = data
|
| 32 |
else:
|
| 33 |
+
default_csv = os.path.join(
|
| 34 |
+
os.path.dirname(os.path.dirname(__file__)), "data", "sample_data.csv"
|
| 35 |
+
)
|
| 36 |
+
self.data = pd.read_csv(default_csv) if os.path.exists(default_csv) else pd.DataFrame()
|
| 37 |
+
|
| 38 |
+
# -----------------------------------------------------------------------
|
| 39 |
+
# Public
|
| 40 |
+
# -----------------------------------------------------------------------
|
| 41 |
+
|
| 42 |
+
def generate_chart(self, plot_args: dict) -> dict:
|
| 43 |
+
"""Return {"chart_path": str, "chart_spec": dict}."""
|
| 44 |
+
t0 = time.time()
|
| 45 |
+
logger.info(f"Generating chart: {plot_args}")
|
| 46 |
+
|
| 47 |
+
x_col = plot_args["x"]
|
| 48 |
+
y_cols = plot_args["y"]
|
| 49 |
+
chart_type = plot_args.get("chart_type", "line")
|
| 50 |
+
color = plot_args.get("color", None)
|
| 51 |
+
|
| 52 |
+
self._validate_columns(x_col, y_cols)
|
| 53 |
+
|
| 54 |
+
chart_path = self._save_matplotlib(x_col, y_cols, chart_type, color)
|
| 55 |
+
chart_spec = self._build_plotly_spec(x_col, y_cols, chart_type, color)
|
| 56 |
+
|
| 57 |
+
logger.info(f"Chart ready in {time.time() - t0:.2f}s")
|
| 58 |
+
return {"chart_path": chart_path, "chart_spec": chart_spec}
|
| 59 |
+
|
| 60 |
+
# -----------------------------------------------------------------------
|
| 61 |
+
# Validation
|
| 62 |
+
# -----------------------------------------------------------------------
|
| 63 |
+
|
| 64 |
+
def _validate_columns(self, x_col: str, y_cols: list):
|
| 65 |
+
missing = [c for c in [x_col] + y_cols if c not in self.data.columns]
|
| 66 |
+
if missing:
|
| 67 |
+
raise ValueError(
|
| 68 |
+
f"Columns not found in data: {missing}. "
|
| 69 |
+
f"Available: {list(self.data.columns)}"
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
# -----------------------------------------------------------------------
|
| 73 |
+
# Matplotlib (static PNG — downloaded or fallback)
|
| 74 |
+
# -----------------------------------------------------------------------
|
| 75 |
+
|
| 76 |
+
    def _save_matplotlib(self, x_col, y_cols, chart_type, color) -> str:
        """Render the chart with matplotlib and save it as a PNG.

        Returns the path relative to the app root ("static/images/<name>.png"),
        which matches what the /static route serves.
        """
        # Reset any global pyplot state left over from a previous request.
        plt.clf()
        plt.close("all")
        fig, ax = plt.subplots(figsize=(10, 6))
        # Dark theme matching the _PLOTLY_LAYOUT colors used for the JSON spec.
        fig.patch.set_facecolor("#0f1117")
        ax.set_facecolor("#0f1117")

        palette = ["#4f8cff", "#34d399", "#f59e0b", "#ef4444", "#a78bfa"]
        x = self.data[x_col]

        for i, y_col in enumerate(y_cols):
            # An explicit color overrides the palette for every series.
            c = color or palette[i % len(palette)]
            y = self.data[y_col]
            if chart_type == "bar":
                ax.bar(x, y, label=y_col, color=c, alpha=0.85)
            elif chart_type == "scatter":
                ax.scatter(x, y, label=y_col, color=c, alpha=0.8)
            elif chart_type == "area":
                ax.fill_between(x, y, label=y_col, color=c, alpha=0.4)
                ax.plot(x, y, color=c)
            elif chart_type == "histogram":
                # Histograms bin the y values; x is unused for this type.
                ax.hist(y, label=y_col, color=c, alpha=0.8, bins="auto", edgecolor="#1e2d3d")
            elif chart_type == "box":
                # NOTE: the comprehension variable y_col shadows the loop
                # variable here; the boxplot consumes every y column at once.
                ax.boxplot(
                    [self.data[y_col].dropna().values for y_col in y_cols],
                    labels=y_cols,
                    patch_artist=True,
                    boxprops=dict(facecolor=c, color="#e2e8f0"),
                    medianprops=dict(color="#f59e0b", linewidth=2),
                )
                break  # box handles all y_cols at once
            elif chart_type == "pie":
                # Pie uses only the first y column as values, x as labels.
                ax.pie(
                    y, labels=x, autopct="%1.1f%%",
                    colors=palette, startangle=90,
                    wedgeprops=dict(edgecolor="#0f1117"),
                )
                ax.set_aspect("equal")
                break
            else:
                # Default: line chart.
                ax.plot(x, y, label=y_col, color=c, marker="o", linewidth=2)

        # Apply the dark-theme styling to axes, labels, and title.
        for spine in ax.spines.values():
            spine.set_edgecolor("#2d3748")
        ax.tick_params(colors="#94a3b8")
        ax.xaxis.label.set_color("#94a3b8")
        ax.yaxis.label.set_color("#94a3b8")
        ax.set_xlabel(x_col, fontsize=11)
        ax.set_ylabel(" / ".join(y_cols), fontsize=11)
        ax.set_title(f"{chart_type.title()} — {', '.join(y_cols)} vs {x_col}",
                     color="#e2e8f0", fontsize=13, pad=12)
        ax.grid(True, alpha=0.15, color="#1e2d3d")
        # Legend/tick tweaks are skipped for types where they add no value.
        if chart_type not in ("pie", "histogram"):
            ax.legend(facecolor="#161b27", edgecolor="#2d3748", labelcolor="#e2e8f0")
        if chart_type not in ("pie", "histogram", "box") and len(x) > 5:
            plt.xticks(rotation=45, ha="right")

        output_dir = os.path.join(os.path.dirname(__file__), "static", "images")
        os.makedirs(output_dir, exist_ok=True)
        # Random uuid filename, so concurrent requests never clobber each other.
        filename = f"chart_{uuid.uuid4().hex[:12]}.png"
        full_path = os.path.join(output_dir, filename)
        plt.savefig(full_path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
        plt.close(fig)
        logger.info(f"Saved PNG: {full_path} ({os.path.getsize(full_path)} bytes)")
        # Relative path, suitable for the /static/<path:filename> route.
        return os.path.join("static", "images", filename)
|
| 141 |
+
|
| 142 |
+
# -----------------------------------------------------------------------
|
| 143 |
+
# Plotly (interactive JSON spec for frontend)
|
| 144 |
+
# -----------------------------------------------------------------------
|
| 145 |
+
|
| 146 |
+
def _build_plotly_spec(self, x_col, y_cols, chart_type, color) -> dict:
|
| 147 |
+
palette = ["#4f8cff", "#34d399", "#f59e0b", "#ef4444", "#a78bfa"]
|
| 148 |
+
x = self.data[x_col].tolist()
|
| 149 |
+
traces = []
|
| 150 |
+
|
| 151 |
+
for i, y_col in enumerate(y_cols):
|
| 152 |
+
c = color or palette[i % len(palette)]
|
| 153 |
+
y = self.data[y_col].tolist()
|
| 154 |
+
|
| 155 |
+
if chart_type == "bar":
|
| 156 |
+
traces.append(go.Bar(x=x, y=y, name=y_col, marker_color=c, opacity=0.85).to_plotly_json())
|
| 157 |
+
elif chart_type == "scatter":
|
| 158 |
+
traces.append(go.Scatter(x=x, y=y, name=y_col, mode="markers",
|
| 159 |
+
marker=dict(color=c, size=8, opacity=0.8)).to_plotly_json())
|
| 160 |
+
elif chart_type == "area":
|
| 161 |
+
traces.append(go.Scatter(x=x, y=y, name=y_col, mode="lines",
|
| 162 |
+
fill="tozeroy", line=dict(color=c),
|
| 163 |
+
fillcolor=c.replace(")", ", 0.25)").replace("rgb", "rgba")
|
| 164 |
+
if c.startswith("rgb") else c).to_plotly_json())
|
| 165 |
+
elif chart_type == "histogram":
|
| 166 |
+
traces.append(go.Histogram(x=y, name=y_col, marker_color=c, opacity=0.8).to_plotly_json())
|
| 167 |
+
elif chart_type == "box":
|
| 168 |
+
traces.append(go.Box(y=y, name=y_col, marker_color=c,
|
| 169 |
+
line_color="#e2e8f0", fillcolor=c).to_plotly_json())
|
| 170 |
+
elif chart_type == "pie":
|
| 171 |
+
traces.append(go.Pie(labels=x, values=y, name=y_col,
|
| 172 |
+
marker=dict(colors=palette)).to_plotly_json())
|
| 173 |
+
break
|
| 174 |
+
else: # line
|
| 175 |
+
traces.append(go.Scatter(x=x, y=y, name=y_col, mode="lines+markers",
|
| 176 |
+
line=dict(color=c, width=2),
|
| 177 |
+
marker=dict(size=6)).to_plotly_json())
|
| 178 |
+
|
| 179 |
+
layout = dict(
|
| 180 |
+
**_PLOTLY_LAYOUT,
|
| 181 |
+
title=dict(
|
| 182 |
+
text=f"{chart_type.title()} — {', '.join(y_cols)} vs {x_col}",
|
| 183 |
+
font=dict(size=15, color="#e2e8f0"),
|
| 184 |
+
),
|
| 185 |
+
xaxis=dict(**_PLOTLY_LAYOUT["xaxis"], title=x_col),
|
| 186 |
+
yaxis=dict(**_PLOTLY_LAYOUT["yaxis"], title=" / ".join(y_cols)),
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
return {"data": traces, "layout": layout}
|
data_processor.py
CHANGED
|
@@ -41,3 +41,25 @@ class DataProcessor:
|
|
| 41 |
def preview(self, n=5):
|
| 42 |
return self.data.head(n).to_dict(orient='records')
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
def preview(self, n=5):
|
| 42 |
return self.data.head(n).to_dict(orient='records')
|
| 43 |
|
| 44 |
+
def get_dtypes(self) -> dict:
|
| 45 |
+
result = {}
|
| 46 |
+
for col, dtype in self.data.dtypes.items():
|
| 47 |
+
if pd.api.types.is_integer_dtype(dtype):
|
| 48 |
+
result[col] = "integer"
|
| 49 |
+
elif pd.api.types.is_float_dtype(dtype):
|
| 50 |
+
result[col] = "float"
|
| 51 |
+
elif pd.api.types.is_datetime64_any_dtype(dtype):
|
| 52 |
+
result[col] = "datetime"
|
| 53 |
+
elif pd.api.types.is_bool_dtype(dtype):
|
| 54 |
+
result[col] = "boolean"
|
| 55 |
+
else:
|
| 56 |
+
result[col] = "string"
|
| 57 |
+
return result
|
| 58 |
+
|
| 59 |
+
def get_stats(self) -> dict:
|
| 60 |
+
numeric = self.data.select_dtypes(include='number')
|
| 61 |
+
if numeric.empty:
|
| 62 |
+
return {}
|
| 63 |
+
desc = numeric.describe().to_dict()
|
| 64 |
+
return {col: {k: round(v, 4) for k, v in stats.items()} for col, stats in desc.items()}
|
| 65 |
+
|
llm_agent.py
CHANGED
|
@@ -1,168 +1,221 @@
|
|
| 1 |
-
import
|
| 2 |
-
|
| 3 |
-
from data_processor import DataProcessor
|
| 4 |
-
from chart_generator import ChartGenerator
|
| 5 |
-
from image_verifier import ImageVerifier
|
| 6 |
-
from huggingface_hub import login
|
| 7 |
import logging
|
| 8 |
-
import time
|
| 9 |
import os
|
|
|
|
|
|
|
| 10 |
from dotenv import load_dotenv
|
| 11 |
-
|
| 12 |
-
import
|
| 13 |
-
import
|
| 14 |
|
| 15 |
load_dotenv()
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
class LLM_Agent:
|
| 18 |
def __init__(self, data_path=None):
|
| 19 |
-
|
| 20 |
self.data_processor = DataProcessor(data_path)
|
| 21 |
self.chart_generator = ChartGenerator(self.data_processor.data)
|
| 22 |
-
self.
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
self.query_tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 27 |
-
self.query_model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
|
| 28 |
-
|
| 29 |
-
def validate_plot_args(plot_args):
|
| 30 |
-
required_keys = ['x', 'y', 'chart_type']
|
| 31 |
-
if not all(key in plot_args for key in required_keys):
|
| 32 |
-
return False
|
| 33 |
-
if not isinstance(plot_args['y'], list):
|
| 34 |
-
plot_args['y'] = [plot_args['y']]
|
| 35 |
-
return True
|
| 36 |
-
|
| 37 |
-
def process_request(self, data):
|
| 38 |
-
start_time = time.time()
|
| 39 |
-
logging.info(f"Processing request data: {data}")
|
| 40 |
-
query = data.get('query', '')
|
| 41 |
-
data_path = data.get('file_path')
|
| 42 |
-
model_choice = data.get('model', 'bart')
|
| 43 |
-
|
| 44 |
-
# Log file path and check existence
|
| 45 |
-
if data_path:
|
| 46 |
-
logging.info(f"Data path received: {data_path}")
|
| 47 |
-
import os
|
| 48 |
-
if not os.path.exists(data_path):
|
| 49 |
-
logging.error(f"File does not exist at path: {data_path}")
|
| 50 |
-
else:
|
| 51 |
-
logging.info(f"File exists at path: {data_path}")
|
| 52 |
-
|
| 53 |
-
# Re-initialize data processor and chart generator if a file is specified
|
| 54 |
-
if data_path:
|
| 55 |
-
self.data_processor = DataProcessor(data_path)
|
| 56 |
-
# Log loaded columns
|
| 57 |
-
loaded_columns = self.data_processor.get_columns()
|
| 58 |
-
logging.info(f"Loaded columns from data: {loaded_columns}")
|
| 59 |
-
self.chart_generator = ChartGenerator(self.data_processor.data)
|
| 60 |
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
"
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
"Example 3:\n"
|
| 73 |
-
"User: display the EBITDA for each year with a blue bar\n"
|
| 74 |
-
"Output: {'x': 'Year', 'y': ['EBITDA'], 'chart_type': 'bar', 'color': 'blue'}\n\n"
|
| 75 |
-
f"User: {query}\nOutput:"
|
| 76 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
try:
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
response_text = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"
|
| 103 |
-
elif model_choice == 'flan-ul2':
|
| 104 |
-
# Use Hugging Face Inference API with Flan-T5-XXL model (best available)
|
| 105 |
-
api_url = "https://api-inference.huggingface.co/models/google/flan-t5-xxl"
|
| 106 |
-
headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACEHUB_API_TOKEN')}"}
|
| 107 |
-
payload = {"inputs": enhanced_prompt}
|
| 108 |
-
|
| 109 |
-
response = requests.post(api_url, headers=headers, json=payload, timeout=30)
|
| 110 |
-
if response.status_code != 200:
|
| 111 |
-
logging.error(f"Hugging Face API error: {response.status_code} {response.text}")
|
| 112 |
-
response_text = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"
|
| 113 |
-
else:
|
| 114 |
-
try:
|
| 115 |
-
resp_json = response.json()
|
| 116 |
-
response_text = resp_json[0]['generated_text'] if isinstance(resp_json, list) else resp_json.get('generated_text', '')
|
| 117 |
-
if not response_text:
|
| 118 |
-
response_text = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"
|
| 119 |
-
except Exception as e:
|
| 120 |
-
logging.error(f"Error parsing Hugging Face API response: {e}, raw: {response.text}")
|
| 121 |
-
response_text = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"
|
| 122 |
-
else:
|
| 123 |
-
# Default fallback to local fine-tuned BART model
|
| 124 |
-
inputs = self.query_tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
|
| 125 |
-
outputs = self.query_model.generate(**inputs, max_length=100, num_return_sequences=1)
|
| 126 |
-
response_text = self.query_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 127 |
-
|
| 128 |
-
logging.info(f"LLM response text: {response_text}")
|
| 129 |
-
|
| 130 |
-
# Clean and parse the response
|
| 131 |
-
response_text = response_text.strip()
|
| 132 |
-
if response_text.startswith("```") and response_text.endswith("```"):
|
| 133 |
-
response_text = response_text[3:-3].strip()
|
| 134 |
-
if response_text.startswith("python"):
|
| 135 |
-
response_text = response_text[6:].strip()
|
| 136 |
-
|
| 137 |
-
try:
|
| 138 |
-
plot_args = ast.literal_eval(response_text)
|
| 139 |
-
except (SyntaxError, ValueError) as e:
|
| 140 |
-
logging.warning(f"Invalid LLM response: {e}. Response: {response_text}")
|
| 141 |
-
plot_args = {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}
|
| 142 |
-
|
| 143 |
-
if not LLM_Agent.validate_plot_args(plot_args):
|
| 144 |
-
logging.warning("Invalid plot arguments. Using default.")
|
| 145 |
-
plot_args = {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}
|
| 146 |
-
|
| 147 |
-
chart_path = self.chart_generator.generate_chart(plot_args)
|
| 148 |
-
verified = self.image_verifier.verify(chart_path, query)
|
| 149 |
-
|
| 150 |
-
end_time = time.time()
|
| 151 |
-
logging.info(f"Processed request in {end_time - start_time} seconds")
|
| 152 |
-
|
| 153 |
-
return {
|
| 154 |
-
"response": response_text,
|
| 155 |
-
"chart_path": chart_path,
|
| 156 |
-
"verified": verified
|
| 157 |
-
}
|
| 158 |
-
|
| 159 |
-
except Exception as e:
|
| 160 |
-
logging.error(f"Error processing request: {e}")
|
| 161 |
-
end_time = time.time()
|
| 162 |
-
logging.info(f"Processed request in {end_time - start_time} seconds")
|
| 163 |
-
|
| 164 |
return {
|
| 165 |
-
"response":
|
| 166 |
"chart_path": "",
|
| 167 |
-
"
|
|
|
|
|
|
|
| 168 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ast
|
| 2 |
+
import json
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import logging
|
|
|
|
| 4 |
import os
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
+
|
| 9 |
+
from chart_generator import ChartGenerator
|
| 10 |
+
from data_processor import DataProcessor
|
| 11 |
|
| 12 |
load_dotenv()
|
| 13 |
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
# ---------------------------------------------------------------------------
|
| 17 |
+
# Prompt templates
|
| 18 |
+
# ---------------------------------------------------------------------------
|
| 19 |
+
|
| 20 |
+
_SYSTEM_PROMPT = (
|
| 21 |
+
"You are a data visualization expert. "
|
| 22 |
+
"Given the user request and the dataset schema provided, output ONLY a valid JSON "
|
| 23 |
+
"object — no explanation, no markdown fences, no extra text.\n\n"
|
| 24 |
+
"Required keys:\n"
|
| 25 |
+
' "x" : string — exact column name for the x-axis\n'
|
| 26 |
+
' "y" : array — one or more exact column names for the y-axis\n'
|
| 27 |
+
' "chart_type" : string — one of: line, bar, scatter, pie, histogram, box, area\n'
|
| 28 |
+
' "color" : string — optional CSS color, e.g. "red", "#4f8cff"\n\n'
|
| 29 |
+
"Rules:\n"
|
| 30 |
+
"- Use only column names that appear in the schema. Never invent names.\n"
|
| 31 |
+
"- For pie: y must contain exactly one column.\n"
|
| 32 |
+
"- For histogram/box: x may equal the first element of y.\n"
|
| 33 |
+
"- Default to line if chart type is ambiguous."
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _user_message(query: str, columns: list, dtypes: dict, sample_rows: list) -> str:
|
| 38 |
+
schema = "\n".join(f" - {c} ({dtypes.get(c, 'unknown')})" for c in columns)
|
| 39 |
+
samples = "".join(f" {json.dumps(r)}\n" for r in sample_rows[:3])
|
| 40 |
+
return (
|
| 41 |
+
f"Dataset columns:\n{schema}\n\n"
|
| 42 |
+
f"Sample rows (first 3):\n{samples}\n"
|
| 43 |
+
f"User request: {query}"
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
# Output parsing & validation
|
| 49 |
+
# ---------------------------------------------------------------------------
|
| 50 |
+
|
| 51 |
+
def _parse_output(text: str):
|
| 52 |
+
text = text.strip()
|
| 53 |
+
if "```" in text:
|
| 54 |
+
for part in text.split("```"):
|
| 55 |
+
part = part.strip().lstrip("json").strip()
|
| 56 |
+
if part.startswith("{"):
|
| 57 |
+
text = part
|
| 58 |
+
break
|
| 59 |
+
try:
|
| 60 |
+
return json.loads(text)
|
| 61 |
+
except json.JSONDecodeError:
|
| 62 |
+
pass
|
| 63 |
+
try:
|
| 64 |
+
return ast.literal_eval(text)
|
| 65 |
+
except (SyntaxError, ValueError):
|
| 66 |
+
pass
|
| 67 |
+
return None
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _validate(args: dict, columns: list):
|
| 71 |
+
if not isinstance(args, dict):
|
| 72 |
+
return None
|
| 73 |
+
if not all(k in args for k in ("x", "y", "chart_type")):
|
| 74 |
+
return None
|
| 75 |
+
if isinstance(args["y"], str):
|
| 76 |
+
args["y"] = [args["y"]]
|
| 77 |
+
valid = {"line", "bar", "scatter", "pie", "histogram", "box", "area"}
|
| 78 |
+
if args["chart_type"] not in valid:
|
| 79 |
+
args["chart_type"] = "line"
|
| 80 |
+
if args["x"] not in columns:
|
| 81 |
+
return None
|
| 82 |
+
if not all(c in columns for c in args["y"]):
|
| 83 |
+
return None
|
| 84 |
+
return args
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# ---------------------------------------------------------------------------
|
| 88 |
+
# Agent
|
| 89 |
+
# ---------------------------------------------------------------------------
|
| 90 |
+
|
| 91 |
class LLM_Agent:
    """Multi-model data-analysis agent.

    Turns a natural-language query about an uploaded dataset into validated
    plot arguments (via one of several LLM backends) and renders a chart
    through ``ChartGenerator``.
    """

    def __init__(self, data_path=None):
        """Create the agent, loading the dataset at ``data_path`` if given."""
        logger.info("Initializing LLM_Agent")
        self.data_processor = DataProcessor(data_path)
        self.chart_generator = ChartGenerator(self.data_processor.data)
        # BART is loaded lazily on first use (see _run_bart) to keep startup fast.
        self._bart_tokenizer = None
        self._bart_model = None

    # -- model runners -------------------------------------------------------

    def _run_qwen(self, user_msg: str) -> str:
        """Query Qwen2.5-1.5B-Instruct via the HF Inference API; return raw text."""
        from huggingface_hub import InferenceClient
        client = InferenceClient(token=os.getenv("HUGGINGFACEHUB_API_TOKEN"))
        resp = client.chat_completion(
            model="Qwen/Qwen2.5-1.5B-Instruct",
            messages=[
                {"role": "system", "content": _SYSTEM_PROMPT},
                {"role": "user", "content": user_msg},
            ],
            max_tokens=256,
            temperature=0.1,  # near-deterministic: we want parseable JSON, not creativity
        )
        return resp.choices[0].message.content

    def _run_gemini(self, user_msg: str) -> str:
        """Query Gemini 2.0 Flash; return raw text.

        Raises:
            ValueError: if GEMINI_API_KEY is not configured.
        """
        import google.generativeai as genai
        api_key = os.getenv("GEMINI_API_KEY")
        if not api_key:
            raise ValueError("GEMINI_API_KEY is not set")
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel(
            "gemini-2.0-flash",
            system_instruction=_SYSTEM_PROMPT,
        )
        return model.generate_content(user_msg).text

    def _run_grok(self, user_msg: str) -> str:
        """Query Grok-3 Mini through the OpenAI-compatible x.ai endpoint.

        Raises:
            ValueError: if GROK_API_KEY is not configured.
        """
        from openai import OpenAI
        api_key = os.getenv("GROK_API_KEY")
        if not api_key:
            raise ValueError("GROK_API_KEY is not set")
        client = OpenAI(api_key=api_key, base_url="https://api.x.ai/v1")
        resp = client.chat.completions.create(
            model="grok-3-mini",
            messages=[
                {"role": "system", "content": _SYSTEM_PROMPT},
                {"role": "user", "content": user_msg},
            ],
            max_tokens=256,
            temperature=0.1,
        )
        return resp.choices[0].message.content

    def _run_bart(self, query: str) -> str:
        """Run the locally hosted fine-tuned BART model on ``query``.

        The model and tokenizer are loaded lazily on the first call and
        cached on the instance for subsequent requests.
        """
        if self._bart_model is None:
            from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
            model_id = "ArchCoder/fine-tuned-bart-large"
            logger.info("Loading BART model (first request)...")
            self._bart_tokenizer = AutoTokenizer.from_pretrained(model_id)
            self._bart_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
            logger.info("BART model loaded.")
        inputs = self._bart_tokenizer(
            query, return_tensors="pt", max_length=512, truncation=True
        )
        outputs = self._bart_model.generate(**inputs, max_length=100)
        return self._bart_tokenizer.decode(outputs[0], skip_special_tokens=True)

    # -- main entry point ----------------------------------------------------

    def process_request(self, data: dict) -> dict:
        """Handle one analysis request end to end.

        Args:
            data: request payload with optional keys ``query``, ``file_path``
                (reload the dataset if it exists) and ``model`` (one of
                ``qwen``/``gemini``/``grok``/``bart``; defaults to ``qwen``).

        Returns:
            dict with ``response``, ``chart_path``, ``chart_spec``,
            ``verified`` (False when chart generation failed) and
            ``plot_args``. LLM failures fall back to default plot args
            rather than erroring out.
        """
        t0 = time.time()
        query = data.get("query", "")
        data_path = data.get("file_path")
        model = data.get("model", "qwen")

        # A new upload replaces the active dataset for this and later requests.
        if data_path and os.path.exists(data_path):
            self.data_processor = DataProcessor(data_path)
            self.chart_generator = ChartGenerator(self.data_processor.data)

        columns = self.data_processor.get_columns()
        dtypes = self.data_processor.get_dtypes()
        sample_rows = self.data_processor.preview(3)

        # Used whenever the LLM output is missing, unparseable, or invalid.
        default_args = {
            "x": columns[0] if columns else "Year",
            "y": [columns[1]] if len(columns) > 1 else ["Sales"],
            "chart_type": "line",
        }

        raw_text = ""
        plot_args = None
        try:
            user_msg = _user_message(query, columns, dtypes, sample_rows)
            if model == "gemini":
                raw_text = self._run_gemini(user_msg)
            elif model == "grok":
                raw_text = self._run_grok(user_msg)
            elif model == "bart":
                # BART is fine-tuned on raw queries, not the structured prompt.
                raw_text = self._run_bart(query)
            else:
                raw_text = self._run_qwen(user_msg)

            logger.info("LLM [%s] output: %s", model, raw_text)
            parsed = _parse_output(raw_text)
            plot_args = _validate(parsed, columns) if parsed else None
        except Exception as exc:
            # Any backend failure degrades to the default chart below.
            logger.error("LLM error [%s]: %s", model, exc)
            raw_text = str(exc)

        if not plot_args:
            logger.warning("Falling back to default plot args")
            plot_args = default_args

        try:
            chart_result = self.chart_generator.generate_chart(plot_args)
            chart_path = chart_result["chart_path"]
            chart_spec = chart_result["chart_spec"]
        except Exception as exc:
            logger.error("Chart generation error: %s", exc)
            return {
                "response": f"Chart generation failed: {exc}",
                "chart_path": "",
                "chart_spec": None,
                "verified": False,
                "plot_args": plot_args,
            }

        logger.info("Request processed in %.2fs", time.time() - t0)
        return {
            "response": json.dumps(plot_args),
            "chart_path": chart_path,
            "chart_spec": chart_spec,
            "verified": True,
            "plot_args": plot_args,
        }
|
requirements.txt
CHANGED
|
@@ -19,7 +19,8 @@ Flask-Cors
|
|
| 19 |
fonttools
|
| 20 |
frozenlist
|
| 21 |
fsspec
|
| 22 |
-
|
|
|
|
| 23 |
humanfriendly
|
| 24 |
idna
|
| 25 |
intel-openmp
|
|
@@ -35,11 +36,13 @@ multidict
|
|
| 35 |
multiprocess
|
| 36 |
networkx
|
| 37 |
numpy
|
|
|
|
| 38 |
openpyxl
|
| 39 |
optimum
|
| 40 |
packaging
|
| 41 |
pandas
|
| 42 |
pillow
|
|
|
|
| 43 |
protobuf
|
| 44 |
psutil
|
| 45 |
pyarrow
|
|
|
|
| 19 |
fonttools
|
| 20 |
frozenlist
|
| 21 |
fsspec
|
| 22 |
+
google-generativeai>=0.8.0
|
| 23 |
+
huggingface-hub>=0.23.0
|
| 24 |
humanfriendly
|
| 25 |
idna
|
| 26 |
intel-openmp
|
|
|
|
| 36 |
multiprocess
|
| 37 |
networkx
|
| 38 |
numpy
|
| 39 |
+
openai>=1.0.0
|
| 40 |
openpyxl
|
| 41 |
optimum
|
| 42 |
packaging
|
| 43 |
pandas
|
| 44 |
pillow
|
| 45 |
+
plotly>=5.18.0
|
| 46 |
protobuf
|
| 47 |
psutil
|
| 48 |
pyarrow
|