"""
Data synthesis and visualization pipeline using Pandas, Matplotlib, and Seaborn.

Determines the best visual representation and saves high-resolution images.
Charts use a professional, dashboard-style look with currency formatting and
clear typography.
"""

from pathlib import Path
from typing import Any, List, Optional

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import pandas as pd
import seaborn as sns

from src.models import ChartType, VisualizationConfig

# Professional palette: slate/blue tones (dashboard-style)
CHART_COLOR = "#334155"
CHART_COLORS = ["#475569", "#64748b", "#94a3b8", "#0ea5e9", "#0369a1", "#1e40af"]
MAX_BAR_CATEGORIES = 15  # cap bars so labels stay readable

# Compact currency suffixes, checked largest first: (threshold, suffix).
_CURRENCY_SCALES = ((1e9, "B"), (1e6, "M"), (1e3, "K"))


def _is_currency_column(name: str) -> bool:
    """Heuristic: treat a column as currency if its name suggests money.

    Matching is a case-insensitive substring test, so it can also match
    non-money names (e.g. "total_count", "consumption").
    """
    if not name:
        return False
    n = str(name).lower()
    return any(k in n for k in ("amount", "total", "revenue", "sales", "sum", "price", "value"))


def _trim_decimals(text: str) -> str:
    """Drop trailing zeros and a dangling decimal point from a fixed-point string."""
    # Only strip when there is a fractional part; "100".rstrip("0") would give "1".
    if "." in text:
        text = text.rstrip("0").rstrip(".")
    return text


def _format_currency_tick(x: float, _pos: Optional[int] = None) -> str:
    """Format a tick value as compact currency, e.g. "$950", "$1.5K", "-$2M".

    ``_pos`` is the tick position Matplotlib passes to formatters; it is unused.
    """
    sign = "-" if x < 0 else ""
    magnitude = abs(x)
    for threshold, suffix in _CURRENCY_SCALES:
        if magnitude >= threshold:
            # One decimal keeps neighbouring ticks (1.5K vs 2K) distinct.
            scaled = _trim_decimals(f"{magnitude / threshold:.1f}")
            return f"{sign}${scaled}{suffix}"
    if float(magnitude).is_integer():
        return f"{sign}${magnitude:,.0f}"
    return f"{sign}${magnitude:,.2f}"


def _format_plain_tick(x: float, _pos: Optional[int] = None) -> str:
    """Format a tick value with thousands separators and up to two decimals.

    Whole numbers print without decimals ("1,234"); fractional ticks keep
    their precision ("0.25") instead of being rounded to integers.
    """
    if float(x).is_integer():
        return f"{x:,.0f}"
    return _trim_decimals(f"{x:,.2f}")


def _format_currency_axis(ax, axis="y"):
    """Format the given axis ("x" or "y") as compact currency with K/M/B suffixes."""
    ax_to_use = ax.yaxis if axis == "y" else ax.xaxis
    ax_to_use.set_major_formatter(ticker.FuncFormatter(_format_currency_tick))


def _setup_professional_style(ax, y_col=None, value_axis="y"):
    """Apply consistent dashboard styling to ``ax``.

    Hides the top/right spines and draws a light grid behind the data along the
    value axis. It then formats that axis: as currency when ``y_col`` looks like
    a money column, otherwise as plain numbers with thousands separators.

    Args:
        ax: Matplotlib axes to style.
        y_col: Name of the plotted value column, used only for the currency
            heuristic; ``None`` means plain number formatting.
        value_axis: Axis carrying the numeric values. Use ``"y"`` for most
            charts and ``"x"`` for horizontal bars, so the category labels on
            the y-axis are not overwritten by a numeric formatter.
    """
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.grid(axis=value_axis, alpha=0.3, linestyle="-")
    ax.set_axisbelow(True)
    if y_col and _is_currency_column(y_col):
        _format_currency_axis(ax, value_axis)
    else:
        target = ax.xaxis if value_axis == "x" else ax.yaxis
        target.set_major_formatter(ticker.FuncFormatter(_format_plain_tick))


def _truncate_label(value: Any, max_len: int) -> str:
    """Return ``str(value)`` cut to ``max_len`` characters, with "…" appended if cut."""
    text = str(value)
    return text if len(text) <= max_len else text[:max_len] + "…"


def _draw_bars(ax, labels: List[str], values, horizontal: bool = False) -> None:
    """Draw one bar per label at integer positions and attach the labels as ticks.

    Bars are positioned by index rather than by label text. Two categories whose
    truncated labels happen to collide therefore still get separate bars instead
    of being drawn on top of each other.
    """
    positions = list(range(len(labels)))
    style = {"color": CHART_COLOR, "edgecolor": "white", "linewidth": 0.5}
    if horizontal:
        ax.barh(positions, values, **style)
        ax.set_yticks(positions)
        ax.set_yticklabels(labels)
    else:
        ax.bar(positions, values, **style)
        ax.set_xticks(positions)
        ax.set_xticklabels(labels)


def _top_categories(df: pd.DataFrame, x_col, y_col) -> tuple:
    """Return ``(categories, values)`` for the largest groups of ``x_col``, largest first.

    With ``y_col`` the values are per-category sums of ``y_col``; without it they
    are row counts per category. At most ``MAX_BAR_CATEGORIES`` groups are kept.
    """
    if y_col:
        totals = df.groupby(x_col)[y_col].sum().sort_values(ascending=False)
    else:
        totals = df[x_col].value_counts()  # already sorted, most frequent first
    top = totals.head(MAX_BAR_CATEGORIES)
    return list(top.index), top.values


def _save_figure(fig, output_path: str, **savefig_kwargs: Any) -> None:
    """Save ``fig`` to ``output_path`` at 150 dpi, creating parent directories."""
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_path, dpi=150, bbox_inches="tight", **savefig_kwargs)


def infer_visualization(data: List[dict]) -> VisualizationConfig:
    """
    Infer the best visualization type from the data structure.

    Uses heuristics on column dtypes, checked in this order:

    * no rows -> placeholder BAR titled "No Data";
    * two or more numeric columns and no categorical ones -> SCATTER of the
      first two numeric columns;
    * at least one categorical and one numeric column -> BAR of the first
      numeric column summed by the first categorical column;
    * exactly one numeric column -> HISTOGRAM of it;
    * categorical columns only -> BAR of row counts per category;
    * otherwise (e.g. only datetime columns) -> BAR of the first two columns.

    "Categorical" means object, pandas ``string``, or category dtype.
    """
    if not data:
        return VisualizationConfig(
            chart_type=ChartType.BAR,
            title="No Data",
            x_column=None,
            y_column=None,
        )

    df = pd.DataFrame(data)
    cols = list(df.columns)
    numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
    # "string" picks up pandas StringDtype columns, which include="object" does not select.
    cat_cols = df.select_dtypes(include=["object", "string", "category"]).columns.tolist()

    if len(numeric_cols) >= 2 and len(cat_cols) == 0:
        return VisualizationConfig(
            chart_type=ChartType.SCATTER,
            title="Scatter Plot",
            x_column=numeric_cols[0],
            y_column=numeric_cols[1],
        )
    if len(cat_cols) >= 1 and len(numeric_cols) >= 1:
        return VisualizationConfig(
            chart_type=ChartType.BAR,
            title=f"{numeric_cols[0]} by {cat_cols[0]}",
            x_column=cat_cols[0],
            y_column=numeric_cols[0],
        )
    if len(numeric_cols) == 1:
        return VisualizationConfig(
            chart_type=ChartType.HISTOGRAM,
            title=f"Distribution of {numeric_cols[0]}",
            x_column=numeric_cols[0],
            y_column=None,
        )
    if len(cat_cols) >= 1:
        return VisualizationConfig(
            chart_type=ChartType.BAR,
            title=f"Count of {cat_cols[0]}",
            x_column=cat_cols[0],
            y_column=None,
        )
    return VisualizationConfig(
        chart_type=ChartType.BAR,
        title="Data Overview",
        x_column=cols[0] if cols else None,
        y_column=cols[1] if len(cols) > 1 else None,
    )


def _draw_configured_chart(fig, ax, df: pd.DataFrame, config: VisualizationConfig) -> bool:
    """Draw the chart requested by ``config`` onto ``ax``.

    Returns True if a chart was drawn. Returns False when the chart type lacks
    the columns it needs, or is a HEATMAP without two numeric columns, so the
    caller can fall back to a generic chart.
    """
    chart_type = config.chart_type
    x_col = config.x_column
    y_col = config.y_column

    if chart_type == ChartType.LINE and x_col and y_col:
        df_plot = df.sort_values(x_col).head(100)
        ax.plot(
            df_plot[x_col].astype(str),
            df_plot[y_col],
            color=CHART_COLOR,
            marker="o",
            markersize=5,
            linewidth=2,
        )
        _setup_professional_style(ax, y_col)
        plt.xticks(rotation=45, ha="right")
        return True

    if chart_type in (ChartType.BAR, ChartType.BARH) and x_col:
        horizontal = chart_type == ChartType.BARH
        categories, values = _top_categories(df, x_col, y_col)
        labels = [_truncate_label(c, 25 if horizontal else 20) for c in categories]
        if horizontal:
            # barh draws the first bar at the bottom; reverse so the largest is on top.
            labels, values = labels[::-1], values[::-1]
        _draw_bars(ax, labels, values, horizontal=horizontal)
        _setup_professional_style(ax, y_col, value_axis="x" if horizontal else "y")
        if not horizontal:
            plt.xticks(rotation=45, ha="right")
        return True

    if chart_type == ChartType.PIE and x_col:
        counts = df[x_col].value_counts().head(10)
        ax.pie(
            counts.values,
            labels=counts.index,
            autopct="%1.1f%%",
            colors=CHART_COLORS,
            startangle=90,
        )
        ax.axis("equal")
        return True

    if chart_type == ChartType.SCATTER and x_col and y_col:
        ax.scatter(df[x_col], df[y_col], color=CHART_COLOR, alpha=0.6, s=40)
        _setup_professional_style(ax, y_col)
        return True

    if chart_type == ChartType.HISTOGRAM and x_col:
        values = df[x_col].dropna()
        # Size bins from the values actually plotted; hist() needs at least one bin.
        bins = max(1, min(25, len(values)))
        ax.hist(values, bins=bins, color=CHART_COLOR, edgecolor="white")
        _setup_professional_style(ax)
        return True

    if chart_type == ChartType.BOX and y_col:
        if x_col:
            df.boxplot(column=y_col, by=x_col, ax=ax)
            # pandas adds a "Boxplot grouped by ..." figure suptitle; clear it so only
            # the configured title is shown.
            fig.suptitle("")
        else:
            df.boxplot(column=y_col, ax=ax)
        _setup_professional_style(ax, y_col)
        return True

    if chart_type == ChartType.HEATMAP:
        numeric = df.select_dtypes(include=["number"])
        if len(numeric.columns) >= 2:
            sns.heatmap(numeric.corr(), annot=True, cmap="Blues", ax=ax, fmt=".2f")
            return True

    return False


def _draw_fallback_bars(ax, df: pd.DataFrame) -> None:
    """Draw one bar per row using the first two columns; no-op with fewer than two columns.

    The first column supplies the (truncated) labels. The second supplies bar
    heights when it is numeric; otherwise each bar's height is its row index.
    """
    cols = list(df.columns)[:2]
    if len(cols) != 2:
        return
    label_col, value_col = cols
    is_numeric = pd.api.types.is_numeric_dtype(df[value_col])
    heights = df[value_col] if is_numeric else list(range(len(df)))
    labels = [_truncate_label(v, 20) for v in df[label_col]]
    _draw_bars(ax, labels, heights)
    _setup_professional_style(ax, value_col if is_numeric else None)
    plt.xticks(rotation=45, ha="right")


def create_visualization(
    data: List[dict],
    config: VisualizationConfig,
    output_path: str = "output_chart.png",
) -> str:
    """
    Create and save a high-resolution chart based on the config.

    Draws ``config.chart_type`` using ``config.x_column`` / ``config.y_column``.
    Falls back to a bar chart of the first two columns when the configured
    chart cannot be drawn. That happens when a required column is missing, the
    chart type is not handled, or a HEATMAP has fewer than two numeric columns.
    Empty data produces a "No data to visualize" placeholder image. Images are
    written at 150 dpi, and parent directories are created as needed.

    Side effect: sets the global seaborn style ("ticks") and Matplotlib font
    size (11).

    Returns the path to the saved image.
    """
    df = pd.DataFrame(data)
    if df.empty:
        fig = plt.figure(figsize=(8, 5))
        try:
            plt.text(0.5, 0.5, "No data to visualize", ha="center", va="center", fontsize=14)
            plt.axis("off")
            _save_figure(fig, output_path)
        finally:
            plt.close(fig)
        return output_path

    sns.set_style("ticks")
    plt.rc("font", size=11)
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        fig.patch.set_facecolor("white")
        ax.set_facecolor("#fafafa")

        if not _draw_configured_chart(fig, ax, df, config):
            _draw_fallback_bars(ax, df)

        ax.set_title(config.title, fontsize=14, fontweight="600", color="#1e293b")
        fig.tight_layout()
        _save_figure(fig, output_path, facecolor="white")
    finally:
        # Always release the figure, even when plotting raises, so repeated calls
        # do not accumulate open figures.
        plt.close(fig)
    return output_path