Spaces:
Sleeping
Sleeping
"""
Data synthesis and visualization pipeline using Pandas, Matplotlib, and Seaborn.
Determines the best visual representation and saves high-resolution images.
Charts use a professional, dashboard-style look with currency formatting and clear typography.
"""
| from pathlib import Path | |
| from typing import Any, List, Optional | |
| import matplotlib.pyplot as plt | |
| import matplotlib.ticker as ticker | |
| import pandas as pd | |
| import seaborn as sns | |
| from src.models import ChartType, VisualizationConfig | |
# Dashboard-style palette built from slate and blue tones.
CHART_COLOR = "#334155"  # single-series colour
CHART_COLORS = [
    "#475569",
    "#64748b",
    "#94a3b8",
    "#0ea5e9",
    "#0369a1",
    "#1e40af",
]  # multi-series colours

# Upper bound on the number of plotted categories so tick labels stay readable.
MAX_BAR_CATEGORIES = 15
| def _is_currency_column(name: str) -> bool: | |
| """Heuristic: treat as currency if name suggests money.""" | |
| if not name: | |
| return False | |
| n = str(name).lower() | |
| return any(k in n for k in ("amount", "total", "revenue", "sales", "sum", "price", "value")) | |
def _currency_tick_label(value, _pos=None):
    """Render one tick value as a compact dollar label.

    Examples: 950 -> "$950", 0.5 -> "$0.50", 1500 -> "$1.5K",
    2_000_000 -> "$2M", -1500 -> "-$1.5K".

    The second argument is the tick position that matplotlib passes to
    formatter callbacks; it is not needed here.
    """
    magnitude = abs(value)
    if magnitude >= 1e6:
        scaled, suffix = magnitude / 1e6, "M"
    elif magnitude >= 1e3:
        scaled, suffix = magnitude / 1e3, "K"
    else:
        scaled, suffix = magnitude, ""

    if suffix:
        # One decimal keeps neighbouring ticks (1.5K vs 2K) distinguishable;
        # a trailing ".0" is dropped so round values stay short.
        text = f"{scaled:,.1f}"
        if text.endswith(".0"):
            text = text[:-2]
    elif float(scaled).is_integer():
        text = f"{scaled:,.0f}"
    else:
        # Sub-dollar precision matters for small amounts (prices, unit costs).
        text = f"{scaled:,.2f}"

    # Put the minus sign in front of the dollar sign: "-$2K", not "$-2K".
    sign = "-" if value < 0 else ""
    return f"{sign}${text}{suffix}"


def _format_currency_axis(ax, axis="y"):
    """Format one axis of `ax` as dollars with K/M suffixes for large numbers.

    Args:
        ax: Matplotlib axes whose tick labels should be reformatted.
        axis: "y" formats the y-axis; any other value formats the x-axis.
    """
    target_axis = ax.yaxis if axis == "y" else ax.xaxis
    target_axis.set_major_formatter(ticker.FuncFormatter(_currency_tick_label))
def _setup_professional_style(ax, y_col=None):
    """Apply the shared chart style: open spines, light y-grid, readable y ticks.

    Args:
        ax: Matplotlib axes to style. The numeric values are assumed to run
            along the y-axis, because the y-axis formatter is replaced.
        y_col: Name of the plotted value column. When it looks like a money
            column (see `_is_currency_column`) the y-axis is formatted as
            currency; otherwise plain thousands-separated numbers are used.
    """
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    ax.grid(axis="y", alpha=0.3, linestyle="-")
    ax.set_axisbelow(True)  # keep grid lines behind the plotted marks

    if y_col and _is_currency_column(y_col):
        _format_currency_axis(ax, "y")
        return

    def _plain_tick_label(value, _pos=None):
        # Whole numbers keep the compact "1,200" form. Fractional ticks keep up
        # to two decimals so 0.2 / 0.4 / 0.6 do not all collapse to "0" or "1".
        if float(value).is_integer():
            return f"{value:,.0f}"
        return f"{value:,.2f}".rstrip("0").rstrip(".")

    ax.yaxis.set_major_formatter(ticker.FuncFormatter(_plain_tick_label))
def infer_visualization(data: List[dict]) -> VisualizationConfig:
    """
    Choose a chart type and axis columns from the shape of the data.

    Columns are split into numeric and categorical (object/category dtype)
    groups, and the first match in this table wins:

    - no rows at all                      -> bar chart titled "No Data"
    - categorical and numeric columns     -> bar of first numeric by first categorical
    - categorical columns only            -> bar counting the first categorical
    - two or more numeric columns only    -> scatter of the first two
    - exactly one numeric column only     -> histogram of it
    - anything else                       -> bar over the first one or two columns
    """
    if not data:
        return VisualizationConfig(
            chart_type=ChartType.BAR,
            title="No Data",
            x_column=None,
            y_column=None,
        )

    frame = pd.DataFrame(data)
    numeric = frame.select_dtypes(include=["number"]).columns.tolist()
    categorical = frame.select_dtypes(include=["object", "category"]).columns.tolist()

    if categorical and numeric:
        choice = (
            ChartType.BAR,
            f"{numeric[0]} by {categorical[0]}",
            categorical[0],
            numeric[0],
        )
    elif categorical:
        choice = (ChartType.BAR, f"Count of {categorical[0]}", categorical[0], None)
    elif len(numeric) >= 2:
        choice = (ChartType.SCATTER, "Scatter Plot", numeric[0], numeric[1])
    elif numeric:
        choice = (ChartType.HISTOGRAM, f"Distribution of {numeric[0]}", numeric[0], None)
    else:
        # Neither numeric nor object/category columns: fall back to whatever
        # columns exist, in their original order.
        all_columns = list(frame.columns)
        first = all_columns[0] if all_columns else None
        second = all_columns[1] if len(all_columns) > 1 else None
        choice = (ChartType.BAR, "Data Overview", first, second)

    chart_type, title, x_column, y_column = choice
    return VisualizationConfig(
        chart_type=chart_type,
        title=title,
        x_column=x_column,
        y_column=y_column,
    )
def _truncate_label(value: Any, limit: int) -> str:
    """Return `str(value)` cut to `limit` characters, with an ellipsis if shortened."""
    text = str(value)
    if len(text) > limit:
        return text[:limit] + "…"
    return text


def _style_horizontal_bars(ax, value_col=None) -> None:
    """Style a horizontal bar chart, whose numeric values run along the x-axis.

    The y-axis holds the category names, so its formatter must be left alone:
    replacing it with a numeric formatter would turn the names into 0, 1, 2, ...
    """
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.grid(axis="x", alpha=0.3, linestyle="-")
    ax.set_axisbelow(True)
    if value_col and _is_currency_column(value_col):
        _format_currency_axis(ax, "x")
    else:
        ax.xaxis.set_major_formatter(ticker.FuncFormatter(lambda x, p: f"{x:,.0f}"))


def create_visualization(
    data: List[dict],
    config: VisualizationConfig,
    output_path: str = "output_chart.png",
) -> str:
    """
    Create and save a high-resolution chart based on the config.

    Args:
        data: Rows to plot, one dict per row.
        config: Chart type, title, and the x/y column names to use.
        output_path: Destination image file; missing parent directories are
            created.

    Returns:
        The path to the saved image (the `output_path` argument, unchanged).

    Notes:
        - Empty data produces a placeholder image reading "No data to visualize".
        - A config that lacks the columns its chart type needs produces an
          empty chart carrying only the title.
        - The figure is always closed, even when drawing or saving fails.
    """
    df = pd.DataFrame(data)
    destination = Path(output_path)

    if df.empty:
        fig = plt.figure(figsize=(8, 5))
        try:
            plt.text(0.5, 0.5, "No data to visualize", ha="center", va="center", fontsize=14)
            plt.axis("off")
            destination.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(output_path, dpi=150, bbox_inches="tight")
        finally:
            plt.close(fig)
        return output_path

    sns.set_style("ticks")
    plt.rc("font", size=11)
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        fig.patch.set_facecolor("white")
        ax.set_facecolor("#fafafa")

        chart_type = config.chart_type
        x_col = config.x_column
        y_col = config.y_column

        if chart_type == ChartType.LINE and x_col and y_col:
            df_plot = df.sort_values(x_col).head(100)
            ax.plot(
                df_plot[x_col].astype(str),
                df_plot[y_col],
                color=CHART_COLOR,
                marker="o",
                markersize=5,
                linewidth=2,
            )
            _setup_professional_style(ax, y_col)
            plt.xticks(rotation=45, ha="right")

        elif chart_type == ChartType.BAR and x_col:
            if y_col:
                totals = df.groupby(x_col, as_index=False)[y_col].sum()
                totals = totals.sort_values(y_col, ascending=False).head(MAX_BAR_CATEGORIES)
                labels = [_truncate_label(x, 20) for x in totals[x_col]]
                heights = totals[y_col]
            else:
                counts = df[x_col].value_counts().head(MAX_BAR_CATEGORIES)
                labels = [_truncate_label(x, 20) for x in counts.index]
                heights = counts.values
            ax.bar(labels, heights, color=CHART_COLOR, edgecolor="white", linewidth=0.5)
            _setup_professional_style(ax, y_col)
            plt.xticks(rotation=45, ha="right")

        elif chart_type == ChartType.BARH and x_col:
            if y_col:
                totals = df.groupby(x_col, as_index=False)[y_col].sum()
                # Keep the LARGEST categories first, then flip to ascending so
                # the biggest bar is drawn at the top of the chart.
                totals = totals.sort_values(y_col, ascending=False).head(MAX_BAR_CATEGORIES)
                totals = totals.sort_values(y_col, ascending=True)
                labels = [_truncate_label(x, 25) for x in totals[x_col]]
                widths = totals[y_col]
            else:
                counts = df[x_col].value_counts().head(MAX_BAR_CATEGORIES)
                counts = counts.sort_values(ascending=True)
                labels = [_truncate_label(x, 25) for x in counts.index]
                widths = counts.values
            ax.barh(labels, widths, color=CHART_COLOR, edgecolor="white", linewidth=0.5)
            # Values are on the x-axis here, so the y-oriented shared style
            # helper would overwrite the category labels; style x instead.
            _style_horizontal_bars(ax, y_col)

        elif chart_type == ChartType.PIE and x_col:
            counts = df[x_col].value_counts().head(10)
            ax.pie(
                counts.values,
                labels=counts.index,
                autopct="%1.1f%%",
                colors=CHART_COLORS,
                startangle=90,
            )
            ax.axis("equal")

        elif chart_type == ChartType.SCATTER and x_col and y_col:
            ax.scatter(df[x_col], df[y_col], color=CHART_COLOR, alpha=0.6, s=40)
            _setup_professional_style(ax, y_col)

        elif chart_type == ChartType.HISTOGRAM and x_col:
            ax.hist(df[x_col].dropna(), bins=min(25, len(df)), color=CHART_COLOR, edgecolor="white")
            _setup_professional_style(ax)

        elif chart_type == ChartType.BOX and (x_col or y_col):
            # With a single column configured, that column is the one to
            # summarise, whichever slot it was supplied in.
            value_col = y_col or x_col
            if x_col and y_col:
                df.boxplot(column=y_col, by=x_col, ax=ax)
                # pandas adds its own "Boxplot grouped by ..." figure title,
                # which would collide with the configured chart title.
                fig.suptitle("")
            elif pd.api.types.is_numeric_dtype(df[value_col]):
                df.boxplot(column=value_col, ax=ax)
            _setup_professional_style(ax, value_col)

        elif chart_type == ChartType.HEATMAP:
            numeric = df.select_dtypes(include=["number"])
            if len(numeric.columns) >= 2:
                sns.heatmap(numeric.corr(), annot=True, cmap="Blues", ax=ax, fmt=".2f")

        else:
            # Unmatched type or missing columns: best-effort bar chart of the
            # first two columns, if there are two.
            cols = list(df.columns)[:2]
            if len(cols) == 2:
                second_is_numeric = pd.api.types.is_numeric_dtype(df[cols[1]])
                y_vals = df[cols[1]] if second_is_numeric else list(range(len(df)))
                labels = [_truncate_label(x, 20) for x in df[cols[0]]]
                ax.bar(labels, y_vals, color=CHART_COLOR, edgecolor="white", linewidth=0.5)
                _setup_professional_style(ax, cols[1] if second_is_numeric else None)
                plt.xticks(rotation=45, ha="right")

        ax.set_title(config.title, fontsize=14, fontweight="600", color="#1e293b")
        fig.tight_layout()
        destination.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path, dpi=150, bbox_inches="tight", facecolor="white")
    finally:
        plt.close(fig)
    return output_path