"""LLM-backed chart-generation agent.

Translates a natural-language request plus a dataset schema into validated
plotting arguments (x column, y columns, chart type, color) and renders a
chart. Model output is layered with fallbacks:

    requested model -> JSON parse/validate -> keyword heuristics -> defaults
"""

import ast
import difflib
import json
import logging
import os
import re
import time

from dotenv import load_dotenv

from chart_generator import ChartGenerator
from data_processor import DataProcessor

load_dotenv()
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Model IDs (downloaded at Docker build, cached in HF_HOME)
# ---------------------------------------------------------------------------
QWEN_MODEL_ID = os.getenv("QWEN_MODEL_ID", "Qwen/Qwen2.5-Coder-0.5B-Instruct")
BART_MODEL_ID = os.getenv("BART_MODEL_ID", "ArchCoder/fine-tuned-bart-large")

# ---------------------------------------------------------------------------
# Prompt templates with few-shot examples
# ---------------------------------------------------------------------------
_SYSTEM_PROMPT = """\
You are a data visualization expert. Given the user request and dataset schema, \
output ONLY a valid JSON object. No explanation, no markdown fences, no extra text.

Required JSON keys:
  "x"          : string — exact column name for the x-axis
  "y"          : array — one or more exact column names for the y-axis
  "chart_type" : string — one of: line, bar, scatter, pie, histogram, box, area
  "color"      : string or null — optional CSS color like "red", "#4f8cff"

Rules:
- Use ONLY column names from the schema. Never invent names.
- For pie charts: y must contain exactly one column.
- For histogram/box: x may equal the first element of y.
- Default to "line" if chart type is ambiguous.

### Examples

Example 1:
Schema: Year (integer), Sales (float), Profit (float)
User: "plot sales over the years with a red line"
Output: {"x": "Year", "y": ["Sales"], "chart_type": "line", "color": "red"}

Example 2:
Schema: Month (string), Revenue (float), Expenses (float)
User: "bar chart comparing revenue and expenses by month"
Output: {"x": "Month", "y": ["Revenue", "Expenses"], "chart_type": "bar", "color": null}

Example 3:
Schema: Category (string), Count (integer)
User: "pie chart of count by category"
Output: {"x": "Category", "y": ["Count"], "chart_type": "pie", "color": null}

Example 4:
Schema: Date (string), Temperature (float), Humidity (float)
User: "scatter plot of temperature vs humidity in blue"
Output: {"x": "Temperature", "y": ["Humidity"], "chart_type": "scatter", "color": "blue"}

Example 5:
Schema: Year (integer), Sales (float), Employee expense (float), Marketing expense (float)
User: "show me an area chart of sales and marketing expense over years"
Output: {"x": "Year", "y": ["Sales", "Marketing expense"], "chart_type": "area", "color": null}
"""


def _user_message(query: str, columns: list, dtypes: dict, sample_rows: list) -> str:
    """Build the per-request user prompt: schema, up to 3 sample rows, the query."""
    schema = "\n".join(f" - {c} ({dtypes.get(c, 'unknown')})" for c in columns)
    samples = "".join(f" {json.dumps(r)}\n" for r in sample_rows[:3])
    return (
        f"Schema:\n{schema}\n\n"
        f"Sample rows:\n{samples}\n"
        f"User: \"{query}\"\n"
        f"Output:"
    )


# ---------------------------------------------------------------------------
# Output parsing & validation
# ---------------------------------------------------------------------------
def _parse_output(text: str):
    """Best-effort extraction of a dict from raw model output.

    Strips optional ``` fences, then tries strict JSON, then a Python literal
    (some models emit single-quoted dicts). Returns None on failure.
    """
    text = text.strip()
    if "```" in text:
        for part in text.split("```"):
            part = part.strip()
            if part.startswith("json"):
                # Drop the fence language tag exactly once. The previous
                # lstrip("json") stripped the *character set* j/s/o/n and
                # could eat leading payload characters.
                part = part[len("json"):].strip()
            if part.startswith("{"):
                text = part
                break
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    try:
        # Fallback for Python-literal dicts ('single quotes', True/None, ...).
        return ast.literal_eval(text)
    except (SyntaxError, ValueError):
        pass
    return None


def _validate(args: dict, columns: list):
    """Normalize and sanity-check model-produced plot args.

    Returns the (possibly repaired) dict, or None when it cannot be trusted:
    missing keys, unknown column names, or a malformed "y" value.
    """
    if not isinstance(args, dict):
        return None
    if not all(k in args for k in ("x", "y", "chart_type")):
        return None
    if isinstance(args["y"], str):
        args["y"] = [args["y"]]
    if not isinstance(args["y"], (list, tuple)):
        return None  # e.g. a number or nested dict — unusable
    valid = {"line", "bar", "scatter", "pie", "histogram", "box", "area"}
    if args["chart_type"] not in valid:
        args["chart_type"] = "line"  # prompt contract: default to line
    if args["x"] not in columns:
        return None
    if not all(c in columns for c in args["y"]):
        return None
    return args


def _pick_chart_type(query: str) -> str:
    """Map chart-type keywords in the query to a supported chart type."""
    lowered = query.lower()
    # "line" is checked last: its keywords ("trend", "over time") are the
    # broadest and also serve as the default.
    aliases = {
        "scatter": ["scatter", "scatterplot"],
        "bar": ["bar", "column"],
        "pie": ["pie", "donut"],
        "histogram": ["histogram", "distribution"],
        "box": ["box", "boxplot"],
        "area": ["area"],
        "line": ["line", "trend", "over time", "over the years"],
    }
    for chart_type, keywords in aliases.items():
        if any(keyword in lowered for keyword in keywords):
            return chart_type
    return "line"


def _pick_color(query: str):
    """Return the first whole-word color name mentioned in the query, or None."""
    lowered = query.lower()
    colors = [
        "red", "blue", "green", "yellow", "orange", "purple", "pink",
        "black", "white", "gray", "grey", "cyan", "teal", "indigo",
    ]
    for color in colors:
        if re.search(rf"\b{re.escape(color)}\b", lowered):
            return color
    return None


def _pick_columns(query: str, columns: list, dtypes: dict):
    """Heuristically choose an x column and a single y column for the query.

    Columns are ranked by lexical similarity to the query; year/date-named and
    datetime-typed columns are preferred for x, numeric columns for y.
    """
    lowered = query.lower()
    query_tokens = re.findall(r"[a-zA-Z0-9_]+", lowered)

    def score_column(column: str) -> float:
        col_lower = column.lower()
        score = 0.0
        if col_lower in lowered:
            score += 10.0  # an exact mention dominates everything else
        for token in query_tokens:
            if token and token in col_lower:
                score += 2.0
        score += difflib.SequenceMatcher(None, lowered, col_lower).ratio()
        return score

    sorted_columns = sorted(columns, key=score_column, reverse=True)
    numeric_columns = [c for c in columns if dtypes.get(c) in {"integer", "float"}]
    temporal_columns = [c for c in columns if dtypes.get(c) == "datetime"]
    year_like = [
        c for c in columns
        if "year" in c.lower() or "date" in c.lower() or "month" in c.lower()
    ]

    # First candidate wins: year/date-named, then datetime-typed, then the
    # best lexical match. (All candidate lists are derived from `columns`,
    # so no membership re-check is needed.)
    candidates = year_like + temporal_columns + sorted_columns
    x_col = candidates[0] if candidates else None
    if x_col is None and columns:
        x_col = columns[0]

    y_candidates = [c for c in sorted_columns if c != x_col and c in numeric_columns]
    if not y_candidates:
        y_candidates = [c for c in numeric_columns if c != x_col]
    if not y_candidates:
        y_candidates = [c for c in columns if c != x_col]
    return x_col, y_candidates[:1]


def _heuristic_plot_args(query: str, columns: list, dtypes: dict) -> dict:
    """Keyword-based fallback when no model produced usable plot args."""
    x_col, y_cols = _pick_columns(query, columns, dtypes)
    if not x_col:
        x_col = "Year"  # last-ditch default; _validate rejects it if absent
    if not y_cols:
        # Any column other than x. (next() yields a string or None, never a
        # tuple — the previous isinstance(..., tuple) branch was dead code.)
        fallback = next((c for c in columns if c != x_col), None)
        y_cols = [fallback] if fallback is not None else list(columns[:1])
    if isinstance(y_cols, str):
        y_cols = [y_cols]
    return {
        "x": x_col,
        "y": y_cols,
        "chart_type": _pick_chart_type(query),
        "color": _pick_color(query),
    }


# ---------------------------------------------------------------------------
# Agent
# ---------------------------------------------------------------------------
class LLM_Agent:
    """Orchestrates data loading, LLM plot-arg extraction, and chart rendering."""

    def __init__(self, data_path=None):
        logger.info("Initializing LLM_Agent")
        self.data_processor = DataProcessor(data_path)
        self.chart_generator = ChartGenerator(self.data_processor.data)
        # Heavy HF models are lazy-loaded on first use and cached on self.
        self._bart_tokenizer = None
        self._bart_model = None
        self._qwen_tokenizer = None
        self._qwen_model = None

    # -- model runners -------------------------------------------------------
    def _run_qwen(self, user_msg: str) -> str:
        """Qwen2.5-Coder-0.5B-Instruct — fast structured-JSON generation."""
        if self._qwen_model is None:
            from transformers import AutoModelForCausalLM, AutoTokenizer

            logger.info("Loading Qwen model: %s", QWEN_MODEL_ID)
            self._qwen_tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_ID)
            self._qwen_model = AutoModelForCausalLM.from_pretrained(QWEN_MODEL_ID)
            logger.info("Qwen model loaded.")
        messages = [
            {"role": "system", "content": _SYSTEM_PROMPT},
            {"role": "user", "content": user_msg},
        ]
        text = self._qwen_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self._qwen_tokenizer(text, return_tensors="pt")
        outputs = self._qwen_model.generate(
            **inputs, max_new_tokens=256, temperature=0.1, do_sample=True
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        return self._qwen_tokenizer.decode(
            outputs[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True
        )

    def _run_gemini(self, user_msg: str) -> str:
        """Gemini 2.0 Flash via the google-generativeai SDK."""
        import google.generativeai as genai

        api_key = os.getenv("GEMINI_API_KEY")
        if not api_key:
            raise ValueError("GEMINI_API_KEY is not set")
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel(
            "gemini-2.0-flash",
            system_instruction=_SYSTEM_PROMPT,
        )
        return model.generate_content(user_msg).text

    def _run_grok(self, user_msg: str) -> str:
        """Grok 3 Mini via the OpenAI-compatible x.ai endpoint."""
        from openai import OpenAI

        api_key = os.getenv("GROK_API_KEY")
        if not api_key:
            raise ValueError("GROK_API_KEY is not set")
        client = OpenAI(api_key=api_key, base_url="https://api.x.ai/v1")
        resp = client.chat.completions.create(
            model="grok-3-mini",
            messages=[
                {"role": "system", "content": _SYSTEM_PROMPT},
                {"role": "user", "content": user_msg},
            ],
            max_tokens=256,
            temperature=0.1,
        )
        return resp.choices[0].message.content

    def _run_bart(self, query: str) -> str:
        """ArchCoder/fine-tuned-bart-large — lightweight Seq2Seq fallback."""
        if self._bart_model is None:
            from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

            logger.info("Loading BART model: %s", BART_MODEL_ID)
            self._bart_tokenizer = AutoTokenizer.from_pretrained(BART_MODEL_ID)
            self._bart_model = AutoModelForSeq2SeqLM.from_pretrained(BART_MODEL_ID)
            logger.info("BART model loaded.")
        inputs = self._bart_tokenizer(
            query, return_tensors="pt", max_length=512, truncation=True
        )
        outputs = self._bart_model.generate(**inputs, max_length=100)
        return self._bart_tokenizer.decode(outputs[0], skip_special_tokens=True)

    # -- main entry point ----------------------------------------------------
    def process_request(self, data: dict) -> dict:
        """Handle one chart request.

        `data` keys: "query" (natural-language request), optional "file_path"
        (dataset to load), optional "model" (qwen | gemini | grok | bart).
        Returns a dict with response text, chart path/spec, a verified flag,
        and the plot args used. Never raises: model failures fall back to
        heuristics, and chart failures return verified=False.
        """
        t0 = time.time()
        query = data.get("query", "")
        data_path = data.get("file_path")
        model = data.get("model", "qwen")

        if data_path and os.path.exists(data_path):
            # Re-point processor and generator at the newly uploaded dataset.
            self.data_processor = DataProcessor(data_path)
            self.chart_generator = ChartGenerator(self.data_processor.data)

        columns = self.data_processor.get_columns()
        dtypes = self.data_processor.get_dtypes()
        sample_rows = self.data_processor.preview(3)

        # Absolute last resort if even the heuristics fail validation.
        # "color" is included so consumers see the same key set as every
        # other producer of plot args.
        default_args = {
            "x": columns[0] if columns else "Year",
            "y": [columns[1]] if len(columns) > 1 else ["Sales"],
            "chart_type": "line",
            "color": None,
        }

        raw_text = ""
        plot_args = None
        try:
            user_msg = _user_message(query, columns, dtypes, sample_rows)
            if model == "gemini":
                raw_text = self._run_gemini(user_msg)
            elif model == "grok":
                raw_text = self._run_grok(user_msg)
            elif model == "bart":
                raw_text = self._run_bart(query)
            elif model == "qwen":
                try:
                    raw_text = self._run_qwen(user_msg)
                except Exception as qwen_exc:
                    logger.warning("Qwen failed, falling back to BART: %s", qwen_exc)
                    raw_text = self._run_bart(query)
            else:
                # Unknown model name: treat as the default backend.
                raw_text = self._run_qwen(user_msg)
            logger.info("LLM [%s] output: %s", model, raw_text)
            parsed = _parse_output(raw_text)
            plot_args = _validate(parsed, columns) if parsed else None
        except Exception:
            # Top-level boundary: any backend failure routes to heuristics.
            logger.exception("LLM error [%s]", model)

        if not plot_args:
            logger.warning("Falling back to heuristic plot args")
            plot_args = (
                _validate(_heuristic_plot_args(query, columns, dtypes), columns)
                or default_args
            )

        try:
            chart_result = self.chart_generator.generate_chart(plot_args)
            chart_path = chart_result["chart_path"]
            chart_spec = chart_result["chart_spec"]
        except Exception as exc:
            logger.exception("Chart generation error")
            return {
                "response": f"Chart generation failed: {exc}",
                "chart_path": "",
                "chart_spec": None,
                "verified": False,
                "plot_args": plot_args,
            }

        logger.info("Request processed in %.2fs", time.time() - t0)
        return {
            "response": json.dumps(plot_args),
            "chart_path": chart_path,
            "chart_spec": chart_spec,
            "verified": True,
            "plot_args": plot_args,
        }