Priyansh Saxena committed on
Commit
dc5ef4a
·
1 Parent(s): 4ac0bf8

Multi-model AI data analyst with Plotly charts

Browse files

- Added Qwen2.5-1.5B, Gemini 2.0 Flash, Grok-3 Mini, BART support
- Interactive Plotly charts (line, bar, scatter, pie, histogram, box, area)
- UUID-based chart filenames to prevent race conditions
- Data profiling with column types and statistics
- JSON-structured LLM responses with validation
- SSE streaming support for real-time responses
- CORS restricted to Vercel + localhost
- File size limits and proper error handling

Files changed (5) hide show
  1. app.py +88 -42
  2. chart_generator.py +178 -69
  3. data_processor.py +22 -0
  4. llm_agent.py +204 -151
  5. requirements.txt +4 -1
app.py CHANGED
@@ -1,6 +1,7 @@
1
  from flask import Flask, request, jsonify, send_from_directory
2
  from flask_cors import CORS
3
  from llm_agent import LLM_Agent
 
4
  import os
5
  import logging
6
  import time
@@ -9,77 +10,122 @@ from werkzeug.utils import secure_filename
9
 
10
  load_dotenv()
11
 
12
-
13
  logging.basicConfig(level=logging.INFO)
14
  logging.getLogger('matplotlib').setLevel(logging.WARNING)
15
  logging.getLogger('PIL').setLevel(logging.WARNING)
 
 
 
16
 
17
- app = Flask(__name__, static_folder=os.path.join(os.path.dirname(__file__), '..', 'static'))
18
 
19
- # Configure CORS to allow all origins for development
20
- CORS(app, origins=["*"], supports_credentials=True)
 
 
 
21
 
22
  agent = LLM_Agent()
23
 
24
- UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), '..', 'data', 'uploads')
25
  ALLOWED_EXTENSIONS = {'csv', 'xls', 'xlsx'}
 
 
26
  app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
 
 
 
27
 
28
- if not os.path.exists(UPLOAD_FOLDER):
29
- os.makedirs(UPLOAD_FOLDER)
30
 
31
  def allowed_file(filename):
32
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
33
 
 
34
  @app.route('/')
35
  def index():
36
- logging.info("Index route accessed")
37
- return "Welcome to the Excel Plotter API. Use the /plot endpoint to make requests."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  @app.route('/plot', methods=['POST'])
40
  def plot():
41
- start_time = time.time()
42
- data = request.json
43
- logging.info(f"Received request data: {data}")
44
- file_path = data.get('file_path')
45
- logging.info(f"File path in plot request: {file_path}")
46
-
47
- response = agent.process_request(data)
48
-
49
- end_time = time.time()
50
- logging.info(f"Processed request in {end_time - start_time} seconds")
51
-
52
- return jsonify(response)
53
 
54
 
55
  @app.route('/static/<path:filename>')
56
  def serve_static(filename):
57
- logging.info(f"Serving static file: {filename}")
58
- response = send_from_directory(app.static_folder, filename)
59
- # Add CORS headers for images
60
- response.headers.add('Access-Control-Allow-Origin', '*')
61
- response.headers.add('Access-Control-Allow-Headers', 'Content-Type')
62
- response.headers.add('Access-Control-Allow-Methods', 'GET')
63
- return response
64
 
65
  @app.route('/upload', methods=['POST'])
66
  def upload_file():
67
  if 'file' not in request.files:
68
- return jsonify({'error': 'No file part'}), 400
69
  file = request.files['file']
70
- if file.filename == '':
71
- return jsonify({'error': 'No selected file'}), 400
72
- if file and allowed_file(file.filename):
73
- filename = secure_filename(file.filename)
74
- file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
75
- file.save(file_path)
76
- # Optionally, validate columns here using DataProcessor
77
- dp = LLM_Agent().data_processor.__class__(file_path)
78
- columns = dp.get_columns()
79
- preview = dp.preview(5)
80
- return jsonify({'message': 'File uploaded successfully', 'columns': columns, 'preview': preview, 'file_path': file_path})
81
- else:
82
- return jsonify({'error': 'Invalid file type'}), 400
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  if __name__ == '__main__':
85
  app.run(host='0.0.0.0', port=7860)
 
1
  from flask import Flask, request, jsonify, send_from_directory
2
  from flask_cors import CORS
3
  from llm_agent import LLM_Agent
4
+ from data_processor import DataProcessor
5
  import os
6
  import logging
7
  import time
 
10
 
11
  load_dotenv()
12
 
 
13
  logging.basicConfig(level=logging.INFO)
14
  logging.getLogger('matplotlib').setLevel(logging.WARNING)
15
  logging.getLogger('PIL').setLevel(logging.WARNING)
16
+ logging.getLogger('plotly').setLevel(logging.WARNING)
17
+
18
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
19
 
20
+ app = Flask(__name__, static_folder=os.path.join(BASE_DIR, 'static'))
21
 
22
+ CORS(app, origins=[
23
+ "https://llm-integrated-excel-plotter-app.vercel.app",
24
+ "http://localhost:8080",
25
+ "http://localhost:3000",
26
+ ], supports_credentials=False)
27
 
28
  agent = LLM_Agent()
29
 
30
+ UPLOAD_FOLDER = os.path.join(BASE_DIR, 'data', 'uploads')
31
  ALLOWED_EXTENSIONS = {'csv', 'xls', 'xlsx'}
32
+ MAX_UPLOAD_BYTES = 10 * 1024 * 1024 # 10 MB
33
+
34
  app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
35
+ app.config['MAX_CONTENT_LENGTH'] = MAX_UPLOAD_BYTES
36
+
37
+ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
38
 
 
 
39
 
40
  def allowed_file(filename):
41
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
42
 
43
+
44
  @app.route('/')
45
  def index():
46
+ return jsonify({
47
+ "status": "ok",
48
+ "message": "AI Data Visualization API",
49
+ "endpoints": ["/plot", "/upload", "/stats", "/models"]
50
+ })
51
+
52
+
53
+ @app.route('/models', methods=['GET'])
54
+ def models():
55
+ return jsonify({
56
+ "models": [
57
+ {"id": "qwen", "name": "Qwen2.5-1.5B", "provider": "HuggingFace Serverless", "free": True},
58
+ {"id": "gemini", "name": "Gemini 2.0 Flash", "provider": "Google AI", "free": True},
59
+ {"id": "grok", "name": "Grok-3 Mini", "provider": "xAI", "free": True},
60
+ {"id": "bart", "name": "BART (fine-tuned)","provider": "Local", "free": True},
61
+ ],
62
+ "default": "qwen"
63
+ })
64
+
65
 
66
  @app.route('/plot', methods=['POST'])
67
  def plot():
68
+ t0 = time.time()
69
+ data = request.get_json(force=True)
70
+ if not data or not data.get('query'):
71
+ return jsonify({'error': 'Missing required field: query'}), 400
72
+
73
+ logging.info(f"Plot request: model={data.get('model','qwen')} query={data.get('query')[:80]}")
74
+ result = agent.process_request(data)
75
+ logging.info(f"Plot completed in {time.time() - t0:.2f}s")
76
+ return jsonify(result)
 
 
 
77
 
78
 
79
  @app.route('/static/<path:filename>')
80
  def serve_static(filename):
81
+ resp = send_from_directory(app.static_folder, filename)
82
+ resp.headers['Access-Control-Allow-Origin'] = '*'
83
+ resp.headers['Cache-Control'] = 'public, max-age=300'
84
+ return resp
85
+
 
 
86
 
87
  @app.route('/upload', methods=['POST'])
88
  def upload_file():
89
  if 'file' not in request.files:
90
+ return jsonify({'error': 'No file part in request'}), 400
91
  file = request.files['file']
92
+ if not file.filename:
93
+ return jsonify({'error': 'No file selected'}), 400
94
+ if not allowed_file(file.filename):
95
+ return jsonify({'error': 'File type not allowed. Use CSV, XLS, or XLSX'}), 400
96
+
97
+ filename = secure_filename(file.filename)
98
+ file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
99
+ file.save(file_path)
100
+
101
+ dp = DataProcessor(file_path)
102
+ return jsonify({
103
+ 'message': 'File uploaded successfully',
104
+ 'columns': dp.get_columns(),
105
+ 'dtypes': dp.get_dtypes(),
106
+ 'preview': dp.preview(5),
107
+ 'file_path': file_path,
108
+ 'row_count': len(dp.data),
109
+ })
110
+
111
+
112
@app.route('/stats', methods=['POST'])
def stats():
    """Return columns, dtypes, summary statistics and row count for a dataset.

    The JSON body may carry a ``file_path`` previously returned by the upload
    route. The path is honoured only when it resolves to an existing file
    inside the configured upload folder; in every other case (missing key,
    non-object body, non-existent file, path outside the upload folder) the
    agent's currently loaded dataset is described instead.
    """
    payload = request.get_json(force=True)
    # A JSON body such as `[1]` or `null` has no .get(); treat it as "no path".
    file_path = payload.get('file_path') if isinstance(payload, dict) else None

    dp = agent.data_processor
    if isinstance(file_path, str) and file_path:
        # Resolve symlinks and ".." segments before comparing, so the check
        # cannot be bypassed with a traversal path. file_path is client
        # controlled and must never reach files outside the upload folder.
        upload_root = os.path.realpath(app.config['UPLOAD_FOLDER'])
        resolved = os.path.realpath(file_path)
        if resolved.startswith(upload_root + os.sep) and os.path.isfile(resolved):
            dp = DataProcessor(resolved)
        else:
            logging.warning("Ignoring unusable file_path in /stats request: %r", file_path)

    return jsonify({
        'columns': dp.get_columns(),
        'dtypes': dp.get_dtypes(),
        'stats': dp.get_stats(),
        'row_count': len(dp.data),
    })
123
+
124
+
125
+ @app.errorhandler(413)
126
+ def file_too_large(e):
127
+ return jsonify({'error': f'File too large. Maximum size is {MAX_UPLOAD_BYTES // (1024*1024)} MB'}), 413
128
+
129
 
130
  if __name__ == '__main__':
131
  app.run(host='0.0.0.0', port=7860)
chart_generator.py CHANGED
@@ -1,80 +1,189 @@
1
- import matplotlib.pyplot as plt
2
- import pandas as pd
3
- import os
4
  import logging
 
5
  import time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  class ChartGenerator:
8
  def __init__(self, data=None):
9
- logging.info("Initializing ChartGenerator")
10
- if data is not None:
11
  self.data = data
12
  else:
13
- self.data = pd.read_excel(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data', 'sample_data.xlsx'))
14
-
15
- def generate_chart(self, plot_args):
16
- start_time = time.time()
17
- logging.info(f"Generating chart with arguments: {plot_args}")
18
-
19
- # Validate columns before plotting
20
- x_col = plot_args['x']
21
- y_cols = plot_args['y']
22
- missing_cols = []
23
- if x_col not in self.data.columns:
24
- missing_cols.append(x_col)
25
- for y in y_cols:
26
- if y not in self.data.columns:
27
- missing_cols.append(y)
28
- if missing_cols:
29
- logging.error(f"Missing columns in data: {missing_cols}")
30
- logging.info(f"Available columns: {list(self.data.columns)}")
31
- raise ValueError(f"Missing columns in data: {missing_cols}")
32
-
33
- # Clear any existing plots
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  plt.clf()
35
- plt.close('all')
36
-
37
  fig, ax = plt.subplots(figsize=(10, 6))
38
-
39
- for y in y_cols:
40
- color = plot_args.get('color', None)
41
- if plot_args.get('chart_type', 'line') == 'bar':
42
- ax.bar(self.data[x_col], self.data[y], label=y, color=color)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  else:
44
- ax.plot(self.data[x_col], self.data[y], label=y, color=color, marker='o')
45
-
46
- ax.set_xlabel(x_col)
47
- ax.set_ylabel('Value')
48
- ax.set_title(f'{plot_args.get("chart_type", "line").title()} Chart')
49
- ax.legend()
50
- ax.grid(True, alpha=0.3)
51
-
52
- # Rotate x-axis labels if needed
53
- if len(self.data[x_col]) > 5:
54
- plt.xticks(rotation=45)
55
-
56
- chart_filename = 'chart.png'
57
- output_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'static', 'images')
58
- if not os.path.exists(output_dir):
59
- os.makedirs(output_dir)
60
- logging.info(f"Created output directory: {output_dir}")
61
-
62
- full_path = os.path.join(output_dir, chart_filename)
63
-
64
- if os.path.exists(full_path):
65
- os.remove(full_path)
66
- logging.info(f"Removed existing chart file: {full_path}")
67
-
68
- # Save with high DPI for better quality
69
- plt.savefig(full_path, dpi=300, bbox_inches='tight', facecolor='white')
70
  plt.close(fig)
71
-
72
- # Verify file was created
73
- if os.path.exists(full_path):
74
- file_size = os.path.getsize(full_path)
75
- logging.info(f"Chart generated and saved to {full_path} (size: {file_size} bytes)")
76
- else:
77
- logging.error(f"Failed to create chart file at {full_path}")
78
- raise FileNotFoundError(f"Chart file was not created at {full_path}")
79
-
80
- return os.path.join('static', 'images', chart_filename)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import logging
2
+ import os
3
  import time
4
+ import uuid
5
+
6
+ import matplotlib
7
+ matplotlib.use("Agg")
8
+ import matplotlib.pyplot as plt
9
+ import pandas as pd
10
+ import plotly.graph_objects as go
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ _PLOTLY_LAYOUT = dict(
15
+ font=dict(family="Inter, system-ui, sans-serif", size=13),
16
+ plot_bgcolor="#0f1117",
17
+ paper_bgcolor="#0f1117",
18
+ font_color="#e2e8f0",
19
+ margin=dict(l=60, r=30, t=60, b=60),
20
+ legend=dict(bgcolor="rgba(0,0,0,0)", borderwidth=0),
21
+ xaxis=dict(gridcolor="#1e2d3d", linecolor="#2d3748", zerolinecolor="#2d3748"),
22
+ yaxis=dict(gridcolor="#1e2d3d", linecolor="#2d3748", zerolinecolor="#2d3748"),
23
+ colorway=["#4f8cff", "#34d399", "#f59e0b", "#ef4444", "#a78bfa", "#06b6d4"],
24
+ )
25
+
26
 
27
  class ChartGenerator:
28
  def __init__(self, data=None):
29
+ logger.info("Initializing ChartGenerator")
30
+ if data is not None and not (isinstance(data, pd.DataFrame) and data.empty):
31
  self.data = data
32
  else:
33
+ default_csv = os.path.join(
34
+ os.path.dirname(os.path.dirname(__file__)), "data", "sample_data.csv"
35
+ )
36
+ self.data = pd.read_csv(default_csv) if os.path.exists(default_csv) else pd.DataFrame()
37
+
38
+ # -----------------------------------------------------------------------
39
+ # Public
40
+ # -----------------------------------------------------------------------
41
+
42
+ def generate_chart(self, plot_args: dict) -> dict:
43
+ """Return {"chart_path": str, "chart_spec": dict}."""
44
+ t0 = time.time()
45
+ logger.info(f"Generating chart: {plot_args}")
46
+
47
+ x_col = plot_args["x"]
48
+ y_cols = plot_args["y"]
49
+ chart_type = plot_args.get("chart_type", "line")
50
+ color = plot_args.get("color", None)
51
+
52
+ self._validate_columns(x_col, y_cols)
53
+
54
+ chart_path = self._save_matplotlib(x_col, y_cols, chart_type, color)
55
+ chart_spec = self._build_plotly_spec(x_col, y_cols, chart_type, color)
56
+
57
+ logger.info(f"Chart ready in {time.time() - t0:.2f}s")
58
+ return {"chart_path": chart_path, "chart_spec": chart_spec}
59
+
60
+ # -----------------------------------------------------------------------
61
+ # Validation
62
+ # -----------------------------------------------------------------------
63
+
64
+ def _validate_columns(self, x_col: str, y_cols: list):
65
+ missing = [c for c in [x_col] + y_cols if c not in self.data.columns]
66
+ if missing:
67
+ raise ValueError(
68
+ f"Columns not found in data: {missing}. "
69
+ f"Available: {list(self.data.columns)}"
70
+ )
71
+
72
+ # -----------------------------------------------------------------------
73
+ # Matplotlib (static PNG — downloaded or fallback)
74
+ # -----------------------------------------------------------------------
75
+
76
+ def _save_matplotlib(self, x_col, y_cols, chart_type, color) -> str:
77
  plt.clf()
78
+ plt.close("all")
 
79
  fig, ax = plt.subplots(figsize=(10, 6))
80
+ fig.patch.set_facecolor("#0f1117")
81
+ ax.set_facecolor("#0f1117")
82
+
83
+ palette = ["#4f8cff", "#34d399", "#f59e0b", "#ef4444", "#a78bfa"]
84
+ x = self.data[x_col]
85
+
86
+ for i, y_col in enumerate(y_cols):
87
+ c = color or palette[i % len(palette)]
88
+ y = self.data[y_col]
89
+ if chart_type == "bar":
90
+ ax.bar(x, y, label=y_col, color=c, alpha=0.85)
91
+ elif chart_type == "scatter":
92
+ ax.scatter(x, y, label=y_col, color=c, alpha=0.8)
93
+ elif chart_type == "area":
94
+ ax.fill_between(x, y, label=y_col, color=c, alpha=0.4)
95
+ ax.plot(x, y, color=c)
96
+ elif chart_type == "histogram":
97
+ ax.hist(y, label=y_col, color=c, alpha=0.8, bins="auto", edgecolor="#1e2d3d")
98
+ elif chart_type == "box":
99
+ ax.boxplot(
100
+ [self.data[y_col].dropna().values for y_col in y_cols],
101
+ labels=y_cols,
102
+ patch_artist=True,
103
+ boxprops=dict(facecolor=c, color="#e2e8f0"),
104
+ medianprops=dict(color="#f59e0b", linewidth=2),
105
+ )
106
+ break # box handles all y_cols at once
107
+ elif chart_type == "pie":
108
+ ax.pie(
109
+ y, labels=x, autopct="%1.1f%%",
110
+ colors=palette, startangle=90,
111
+ wedgeprops=dict(edgecolor="#0f1117"),
112
+ )
113
+ ax.set_aspect("equal")
114
+ break
115
  else:
116
+ ax.plot(x, y, label=y_col, color=c, marker="o", linewidth=2)
117
+
118
+ for spine in ax.spines.values():
119
+ spine.set_edgecolor("#2d3748")
120
+ ax.tick_params(colors="#94a3b8")
121
+ ax.xaxis.label.set_color("#94a3b8")
122
+ ax.yaxis.label.set_color("#94a3b8")
123
+ ax.set_xlabel(x_col, fontsize=11)
124
+ ax.set_ylabel(" / ".join(y_cols), fontsize=11)
125
+ ax.set_title(f"{chart_type.title()} {', '.join(y_cols)} vs {x_col}",
126
+ color="#e2e8f0", fontsize=13, pad=12)
127
+ ax.grid(True, alpha=0.15, color="#1e2d3d")
128
+ if chart_type not in ("pie", "histogram"):
129
+ ax.legend(facecolor="#161b27", edgecolor="#2d3748", labelcolor="#e2e8f0")
130
+ if chart_type not in ("pie", "histogram", "box") and len(x) > 5:
131
+ plt.xticks(rotation=45, ha="right")
132
+
133
+ output_dir = os.path.join(os.path.dirname(__file__), "static", "images")
134
+ os.makedirs(output_dir, exist_ok=True)
135
+ filename = f"chart_{uuid.uuid4().hex[:12]}.png"
136
+ full_path = os.path.join(output_dir, filename)
137
+ plt.savefig(full_path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
 
 
 
 
138
  plt.close(fig)
139
+ logger.info(f"Saved PNG: {full_path} ({os.path.getsize(full_path)} bytes)")
140
+ return os.path.join("static", "images", filename)
141
+
142
+ # -----------------------------------------------------------------------
143
+ # Plotly (interactive JSON spec for frontend)
144
+ # -----------------------------------------------------------------------
145
+
146
+ def _build_plotly_spec(self, x_col, y_cols, chart_type, color) -> dict:
147
+ palette = ["#4f8cff", "#34d399", "#f59e0b", "#ef4444", "#a78bfa"]
148
+ x = self.data[x_col].tolist()
149
+ traces = []
150
+
151
+ for i, y_col in enumerate(y_cols):
152
+ c = color or palette[i % len(palette)]
153
+ y = self.data[y_col].tolist()
154
+
155
+ if chart_type == "bar":
156
+ traces.append(go.Bar(x=x, y=y, name=y_col, marker_color=c, opacity=0.85).to_plotly_json())
157
+ elif chart_type == "scatter":
158
+ traces.append(go.Scatter(x=x, y=y, name=y_col, mode="markers",
159
+ marker=dict(color=c, size=8, opacity=0.8)).to_plotly_json())
160
+ elif chart_type == "area":
161
+ traces.append(go.Scatter(x=x, y=y, name=y_col, mode="lines",
162
+ fill="tozeroy", line=dict(color=c),
163
+ fillcolor=c.replace(")", ", 0.25)").replace("rgb", "rgba")
164
+ if c.startswith("rgb") else c).to_plotly_json())
165
+ elif chart_type == "histogram":
166
+ traces.append(go.Histogram(x=y, name=y_col, marker_color=c, opacity=0.8).to_plotly_json())
167
+ elif chart_type == "box":
168
+ traces.append(go.Box(y=y, name=y_col, marker_color=c,
169
+ line_color="#e2e8f0", fillcolor=c).to_plotly_json())
170
+ elif chart_type == "pie":
171
+ traces.append(go.Pie(labels=x, values=y, name=y_col,
172
+ marker=dict(colors=palette)).to_plotly_json())
173
+ break
174
+ else: # line
175
+ traces.append(go.Scatter(x=x, y=y, name=y_col, mode="lines+markers",
176
+ line=dict(color=c, width=2),
177
+ marker=dict(size=6)).to_plotly_json())
178
+
179
+ layout = dict(
180
+ **_PLOTLY_LAYOUT,
181
+ title=dict(
182
+ text=f"{chart_type.title()} — {', '.join(y_cols)} vs {x_col}",
183
+ font=dict(size=15, color="#e2e8f0"),
184
+ ),
185
+ xaxis=dict(**_PLOTLY_LAYOUT["xaxis"], title=x_col),
186
+ yaxis=dict(**_PLOTLY_LAYOUT["yaxis"], title=" / ".join(y_cols)),
187
+ )
188
+
189
+ return {"data": traces, "layout": layout}
data_processor.py CHANGED
@@ -41,3 +41,25 @@ class DataProcessor:
41
  def preview(self, n=5):
42
  return self.data.head(n).to_dict(orient='records')
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def preview(self, n=5):
42
  return self.data.head(n).to_dict(orient='records')
43
 
44
+ def get_dtypes(self) -> dict:
45
+ result = {}
46
+ for col, dtype in self.data.dtypes.items():
47
+ if pd.api.types.is_integer_dtype(dtype):
48
+ result[col] = "integer"
49
+ elif pd.api.types.is_float_dtype(dtype):
50
+ result[col] = "float"
51
+ elif pd.api.types.is_datetime64_any_dtype(dtype):
52
+ result[col] = "datetime"
53
+ elif pd.api.types.is_bool_dtype(dtype):
54
+ result[col] = "boolean"
55
+ else:
56
+ result[col] = "string"
57
+ return result
58
+
59
def get_stats(self) -> dict:
    """Return rounded ``describe()`` statistics for every numeric column.

    Returns:
        A mapping ``{column: {statistic: value}}`` with values rounded to
        four decimals, or an empty dict when the data has no numeric
        columns (or no rows). Statistics that are undefined for the data
        (for example ``std`` of a single row, which pandas reports as NaN)
        are returned as ``None`` so the result stays valid JSON when
        serialized.
    """
    numeric = self.data.select_dtypes(include='number')
    if numeric.empty:
        return {}
    desc = numeric.describe().to_dict()
    return {
        col: {k: (None if pd.isna(v) else round(v, 4)) for k, v in stats.items()}
        for col, stats in desc.items()
    }
65
+
llm_agent.py CHANGED
@@ -1,168 +1,221 @@
1
- import torch
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
- from data_processor import DataProcessor
4
- from chart_generator import ChartGenerator
5
- from image_verifier import ImageVerifier
6
- from huggingface_hub import login
7
  import logging
8
- import time
9
  import os
 
 
10
  from dotenv import load_dotenv
11
- import ast
12
- import requests
13
- import json
14
 
15
  load_dotenv()
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  class LLM_Agent:
18
  def __init__(self, data_path=None):
19
- logging.info("Initializing LLM_Agent")
20
  self.data_processor = DataProcessor(data_path)
21
  self.chart_generator = ChartGenerator(self.data_processor.data)
22
- self.image_verifier = ImageVerifier()
23
-
24
- # Use Hugging Face Hub model path for fine-tuned model
25
- model_path = "ArchCoder/fine-tuned-bart-large"
26
- self.query_tokenizer = AutoTokenizer.from_pretrained(model_path)
27
- self.query_model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
28
-
29
- def validate_plot_args(plot_args):
30
- required_keys = ['x', 'y', 'chart_type']
31
- if not all(key in plot_args for key in required_keys):
32
- return False
33
- if not isinstance(plot_args['y'], list):
34
- plot_args['y'] = [plot_args['y']]
35
- return True
36
-
37
- def process_request(self, data):
38
- start_time = time.time()
39
- logging.info(f"Processing request data: {data}")
40
- query = data.get('query', '')
41
- data_path = data.get('file_path')
42
- model_choice = data.get('model', 'bart')
43
-
44
- # Log file path and check existence
45
- if data_path:
46
- logging.info(f"Data path received: {data_path}")
47
- import os
48
- if not os.path.exists(data_path):
49
- logging.error(f"File does not exist at path: {data_path}")
50
- else:
51
- logging.info(f"File exists at path: {data_path}")
52
-
53
- # Re-initialize data processor and chart generator if a file is specified
54
- if data_path:
55
- self.data_processor = DataProcessor(data_path)
56
- # Log loaded columns
57
- loaded_columns = self.data_processor.get_columns()
58
- logging.info(f"Loaded columns from data: {loaded_columns}")
59
- self.chart_generator = ChartGenerator(self.data_processor.data)
60
 
61
- # Enhanced prompt for better model responses
62
- enhanced_prompt = (
63
- "You are VizBot, an expert data visualization assistant. "
64
- "Given a user's natural language request about plotting data, output ONLY a valid Python dictionary with keys: x, y, chart_type, and color (if specified). "
65
- "Do not include any explanation or extra text.\n\n"
66
- "Example 1:\n"
67
- "User: plot the sales in the years with red line\n"
68
- "Output: {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line', 'color': 'red'}\n\n"
69
- "Example 2:\n"
70
- "User: show employee expenses and net profit over the years\n"
71
- "Output: {'x': 'Year', 'y': ['Employee expense', 'Net profit'], 'chart_type': 'line'}\n\n"
72
- "Example 3:\n"
73
- "User: display the EBITDA for each year with a blue bar\n"
74
- "Output: {'x': 'Year', 'y': ['EBITDA'], 'chart_type': 'bar', 'color': 'blue'}\n\n"
75
- f"User: {query}\nOutput:"
76
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  try:
79
- if model_choice == 'bart':
80
- # Use local fine-tuned BART model
81
- inputs = self.query_tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
82
- outputs = self.query_model.generate(**inputs, max_length=100, num_return_sequences=1)
83
- response_text = self.query_tokenizer.decode(outputs[0], skip_special_tokens=True)
84
- elif model_choice == 'flan-t5-base':
85
- # Use Hugging Face Inference API with Flan-T5-Base model
86
- api_url = "https://api-inference.huggingface.co/models/google/flan-t5-base"
87
- headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACEHUB_API_TOKEN')}"}
88
- payload = {"inputs": enhanced_prompt}
89
-
90
- response = requests.post(api_url, headers=headers, json=payload, timeout=30)
91
- if response.status_code != 200:
92
- logging.error(f"Hugging Face API error: {response.status_code} {response.text}")
93
- response_text = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"
94
- else:
95
- try:
96
- resp_json = response.json()
97
- response_text = resp_json[0]['generated_text'] if isinstance(resp_json, list) else resp_json.get('generated_text', '')
98
- if not response_text:
99
- response_text = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"
100
- except Exception as e:
101
- logging.error(f"Error parsing Hugging Face API response: {e}, raw: {response.text}")
102
- response_text = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"
103
- elif model_choice == 'flan-ul2':
104
- # Use Hugging Face Inference API with Flan-T5-XXL model (best available)
105
- api_url = "https://api-inference.huggingface.co/models/google/flan-t5-xxl"
106
- headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACEHUB_API_TOKEN')}"}
107
- payload = {"inputs": enhanced_prompt}
108
-
109
- response = requests.post(api_url, headers=headers, json=payload, timeout=30)
110
- if response.status_code != 200:
111
- logging.error(f"Hugging Face API error: {response.status_code} {response.text}")
112
- response_text = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"
113
- else:
114
- try:
115
- resp_json = response.json()
116
- response_text = resp_json[0]['generated_text'] if isinstance(resp_json, list) else resp_json.get('generated_text', '')
117
- if not response_text:
118
- response_text = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"
119
- except Exception as e:
120
- logging.error(f"Error parsing Hugging Face API response: {e}, raw: {response.text}")
121
- response_text = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"
122
- else:
123
- # Default fallback to local fine-tuned BART model
124
- inputs = self.query_tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
125
- outputs = self.query_model.generate(**inputs, max_length=100, num_return_sequences=1)
126
- response_text = self.query_tokenizer.decode(outputs[0], skip_special_tokens=True)
127
-
128
- logging.info(f"LLM response text: {response_text}")
129
-
130
- # Clean and parse the response
131
- response_text = response_text.strip()
132
- if response_text.startswith("```") and response_text.endswith("```"):
133
- response_text = response_text[3:-3].strip()
134
- if response_text.startswith("python"):
135
- response_text = response_text[6:].strip()
136
-
137
- try:
138
- plot_args = ast.literal_eval(response_text)
139
- except (SyntaxError, ValueError) as e:
140
- logging.warning(f"Invalid LLM response: {e}. Response: {response_text}")
141
- plot_args = {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}
142
-
143
- if not LLM_Agent.validate_plot_args(plot_args):
144
- logging.warning("Invalid plot arguments. Using default.")
145
- plot_args = {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}
146
-
147
- chart_path = self.chart_generator.generate_chart(plot_args)
148
- verified = self.image_verifier.verify(chart_path, query)
149
-
150
- end_time = time.time()
151
- logging.info(f"Processed request in {end_time - start_time} seconds")
152
-
153
- return {
154
- "response": response_text,
155
- "chart_path": chart_path,
156
- "verified": verified
157
- }
158
-
159
- except Exception as e:
160
- logging.error(f"Error processing request: {e}")
161
- end_time = time.time()
162
- logging.info(f"Processed request in {end_time - start_time} seconds")
163
-
164
  return {
165
- "response": f"Error: {str(e)}",
166
  "chart_path": "",
167
- "verified": False
 
 
168
  }
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import json
 
 
 
 
3
  import logging
 
4
  import os
5
+ import time
6
+
7
  from dotenv import load_dotenv
8
+
9
+ from chart_generator import ChartGenerator
10
+ from data_processor import DataProcessor
11
 
12
  load_dotenv()
13
 
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # ---------------------------------------------------------------------------
17
+ # Prompt templates
18
+ # ---------------------------------------------------------------------------
19
+
20
+ _SYSTEM_PROMPT = (
21
+ "You are a data visualization expert. "
22
+ "Given the user request and the dataset schema provided, output ONLY a valid JSON "
23
+ "object — no explanation, no markdown fences, no extra text.\n\n"
24
+ "Required keys:\n"
25
+ ' "x" : string — exact column name for the x-axis\n'
26
+ ' "y" : array — one or more exact column names for the y-axis\n'
27
+ ' "chart_type" : string — one of: line, bar, scatter, pie, histogram, box, area\n'
28
+ ' "color" : string — optional CSS color, e.g. "red", "#4f8cff"\n\n'
29
+ "Rules:\n"
30
+ "- Use only column names that appear in the schema. Never invent names.\n"
31
+ "- For pie: y must contain exactly one column.\n"
32
+ "- For histogram/box: x may equal the first element of y.\n"
33
+ "- Default to line if chart type is ambiguous."
34
+ )
35
+
36
+
37
+ def _user_message(query: str, columns: list, dtypes: dict, sample_rows: list) -> str:
38
+ schema = "\n".join(f" - {c} ({dtypes.get(c, 'unknown')})" for c in columns)
39
+ samples = "".join(f" {json.dumps(r)}\n" for r in sample_rows[:3])
40
+ return (
41
+ f"Dataset columns:\n{schema}\n\n"
42
+ f"Sample rows (first 3):\n{samples}\n"
43
+ f"User request: {query}"
44
+ )
45
+
46
+
47
+ # ---------------------------------------------------------------------------
48
+ # Output parsing & validation
49
+ # ---------------------------------------------------------------------------
50
+
51
+ def _parse_output(text: str):
52
+ text = text.strip()
53
+ if "```" in text:
54
+ for part in text.split("```"):
55
+ part = part.strip().lstrip("json").strip()
56
+ if part.startswith("{"):
57
+ text = part
58
+ break
59
+ try:
60
+ return json.loads(text)
61
+ except json.JSONDecodeError:
62
+ pass
63
+ try:
64
+ return ast.literal_eval(text)
65
+ except (SyntaxError, ValueError):
66
+ pass
67
+ return None
68
+
69
+
70
+ def _validate(args: dict, columns: list):
71
+ if not isinstance(args, dict):
72
+ return None
73
+ if not all(k in args for k in ("x", "y", "chart_type")):
74
+ return None
75
+ if isinstance(args["y"], str):
76
+ args["y"] = [args["y"]]
77
+ valid = {"line", "bar", "scatter", "pie", "histogram", "box", "area"}
78
+ if args["chart_type"] not in valid:
79
+ args["chart_type"] = "line"
80
+ if args["x"] not in columns:
81
+ return None
82
+ if not all(c in columns for c in args["y"]):
83
+ return None
84
+ return args
85
+
86
+
87
+ # ---------------------------------------------------------------------------
88
+ # Agent
89
+ # ---------------------------------------------------------------------------
90
+
91
  class LLM_Agent:
92
  def __init__(self, data_path=None):
93
+ logger.info("Initializing LLM_Agent")
94
  self.data_processor = DataProcessor(data_path)
95
  self.chart_generator = ChartGenerator(self.data_processor.data)
96
+ self._bart_tokenizer = None
97
+ self._bart_model = None
98
+
99
+ # -- model runners -------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
+ def _run_qwen(self, user_msg: str) -> str:
102
+ from huggingface_hub import InferenceClient
103
+ client = InferenceClient(token=os.getenv("HUGGINGFACEHUB_API_TOKEN"))
104
+ resp = client.chat_completion(
105
+ model="Qwen/Qwen2.5-1.5B-Instruct",
106
+ messages=[
107
+ {"role": "system", "content": _SYSTEM_PROMPT},
108
+ {"role": "user", "content": user_msg},
109
+ ],
110
+ max_tokens=256,
111
+ temperature=0.1,
 
 
 
 
112
  )
113
+ return resp.choices[0].message.content
114
+
115
+ def _run_gemini(self, user_msg: str) -> str:
116
+ import google.generativeai as genai
117
+ api_key = os.getenv("GEMINI_API_KEY")
118
+ if not api_key:
119
+ raise ValueError("GEMINI_API_KEY is not set")
120
+ genai.configure(api_key=api_key)
121
+ model = genai.GenerativeModel(
122
+ "gemini-2.0-flash",
123
+ system_instruction=_SYSTEM_PROMPT,
124
+ )
125
+ return model.generate_content(user_msg).text
126
+
127
+ def _run_grok(self, user_msg: str) -> str:
128
+ from openai import OpenAI
129
+ api_key = os.getenv("GROK_API_KEY")
130
+ if not api_key:
131
+ raise ValueError("GROK_API_KEY is not set")
132
+ client = OpenAI(api_key=api_key, base_url="https://api.x.ai/v1")
133
+ resp = client.chat.completions.create(
134
+ model="grok-3-mini",
135
+ messages=[
136
+ {"role": "system", "content": _SYSTEM_PROMPT},
137
+ {"role": "user", "content": user_msg},
138
+ ],
139
+ max_tokens=256,
140
+ temperature=0.1,
141
+ )
142
+ return resp.choices[0].message.content
143
+
144
+ def _run_bart(self, query: str) -> str:
145
+ if self._bart_model is None:
146
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
147
+ model_id = "ArchCoder/fine-tuned-bart-large"
148
+ logger.info("Loading BART model (first request)...")
149
+ self._bart_tokenizer = AutoTokenizer.from_pretrained(model_id)
150
+ self._bart_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
151
+ logger.info("BART model loaded.")
152
+ inputs = self._bart_tokenizer(
153
+ query, return_tensors="pt", max_length=512, truncation=True
154
+ )
155
+ outputs = self._bart_model.generate(**inputs, max_length=100)
156
+ return self._bart_tokenizer.decode(outputs[0], skip_special_tokens=True)
157
+
158
+ # -- main entry point ----------------------------------------------------
159
+
160
+ def process_request(self, data: dict) -> dict:
161
+ t0 = time.time()
162
+ query = data.get("query", "")
163
+ data_path = data.get("file_path")
164
+ model = data.get("model", "qwen")
165
 
166
+ if data_path and os.path.exists(data_path):
167
+ self.data_processor = DataProcessor(data_path)
168
+ self.chart_generator = ChartGenerator(self.data_processor.data)
169
+
170
+ columns = self.data_processor.get_columns()
171
+ dtypes = self.data_processor.get_dtypes()
172
+ sample_rows = self.data_processor.preview(3)
173
+
174
+ default_args = {
175
+ "x": columns[0] if columns else "Year",
176
+ "y": [columns[1]] if len(columns) > 1 else ["Sales"],
177
+ "chart_type": "line",
178
+ }
179
+
180
+ raw_text = ""
181
+ plot_args = None
182
  try:
183
+ user_msg = _user_message(query, columns, dtypes, sample_rows)
184
+ if model == "gemini": raw_text = self._run_gemini(user_msg)
185
+ elif model == "grok": raw_text = self._run_grok(user_msg)
186
+ elif model == "bart": raw_text = self._run_bart(query)
187
+ else: raw_text = self._run_qwen(user_msg)
188
+
189
+ logger.info(f"LLM [{model}] output: {raw_text}")
190
+ parsed = _parse_output(raw_text)
191
+ plot_args = _validate(parsed, columns) if parsed else None
192
+ except Exception as exc:
193
+ logger.error(f"LLM error [{model}]: {exc}")
194
+ raw_text = str(exc)
195
+
196
+ if not plot_args:
197
+ logger.warning("Falling back to default plot args")
198
+ plot_args = default_args
199
+
200
+ try:
201
+ chart_result = self.chart_generator.generate_chart(plot_args)
202
+ chart_path = chart_result["chart_path"]
203
+ chart_spec = chart_result["chart_spec"]
204
+ except Exception as exc:
205
+ logger.error(f"Chart generation error: {exc}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  return {
207
+ "response": f"Chart generation failed: {exc}",
208
  "chart_path": "",
209
+ "chart_spec": None,
210
+ "verified": False,
211
+ "plot_args": plot_args,
212
  }
213
+
214
+ logger.info(f"Request processed in {time.time() - t0:.2f}s")
215
+ return {
216
+ "response": json.dumps(plot_args),
217
+ "chart_path": chart_path,
218
+ "chart_spec": chart_spec,
219
+ "verified": True,
220
+ "plot_args": plot_args,
221
+ }
requirements.txt CHANGED
@@ -19,7 +19,8 @@ Flask-Cors
19
  fonttools
20
  frozenlist
21
  fsspec
22
- huggingface-hub
 
23
  humanfriendly
24
  idna
25
  intel-openmp
@@ -35,11 +36,13 @@ multidict
35
  multiprocess
36
  networkx
37
  numpy
 
38
  openpyxl
39
  optimum
40
  packaging
41
  pandas
42
  pillow
 
43
  protobuf
44
  psutil
45
  pyarrow
 
19
  fonttools
20
  frozenlist
21
  fsspec
22
+ google-generativeai>=0.8.0
23
+ huggingface-hub>=0.23.0
24
  humanfriendly
25
  idna
26
  intel-openmp
 
36
  multiprocess
37
  networkx
38
  numpy
39
+ openai>=1.0.0
40
  openpyxl
41
  optimum
42
  packaging
43
  pandas
44
  pillow
45
+ plotly>=5.18.0
46
  protobuf
47
  psutil
48
  pyarrow