import base64
import io
from pathlib import Path
from typing import Dict, List, Any, Optional, Tuple, Union

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
|
|
| |
# Select the non-interactive Agg backend: charts in this module are only ever
# rendered to in-memory PNG buffers, so no GUI toolkit or display is needed.
matplotlib.use('Agg')
|
|
class VisualizationTools:
    """Render charts from CSV files and return them as base64-encoded PNGs.

    CSV files are looked up inside ``csv_directory`` and cached in memory
    after the first read. Every ``create_*`` method returns a dict describing
    the chart, with the rendered PNG (base64 text) under the ``"image"`` key.
    """

    def __init__(self, csv_directory: str):
        """Store the CSV directory and the default rendering settings.

        Args:
            csv_directory: Directory that contains the CSV files.
        """
        self.csv_directory = csv_directory
        # Cache of loaded frames, keyed by the filename exactly as requested.
        self.dataframes = {}
        self.figure_size = (10, 6)
        self.dpi = 100

    def _load_dataframe(self, filename: str) -> pd.DataFrame:
        """Return the DataFrame for ``filename``, reading it at most once.

        If ``filename`` does not name a file and lacks a ``.csv`` suffix,
        ``<filename>.csv`` is tried as a fallback.

        Args:
            filename: File name relative to ``csv_directory``.

        Returns:
            The cached DataFrame (the same object on repeated calls).

        Raises:
            ValueError: If no matching file exists.
        """
        if filename not in self.dataframes:
            file_path = Path(self.csv_directory) / filename
            # is_file() rather than exists(): a sub-directory that happens to
            # share the requested name must not block the ".csv" fallback.
            if not file_path.is_file() and not filename.endswith('.csv'):
                file_path = Path(self.csv_directory) / f"{filename}.csv"

            if not file_path.is_file():
                raise ValueError(f"CSV file not found: {filename}")

            self.dataframes[filename] = pd.read_csv(file_path)

        return self.dataframes[filename]

    @staticmethod
    def _require_columns(df: pd.DataFrame, filename: str, *columns: str) -> None:
        """Raise ``KeyError`` naming any of ``columns`` missing from ``df``.

        Called before a figure is created so that a bad column name fails
        early, with a message that lists the columns that do exist.
        """
        missing = [name for name in columns if name not in df.columns]
        if missing:
            raise KeyError(
                f"Column(s) {missing} not found in {filename}; "
                f"available columns: {list(df.columns)}"
            )

    def get_tools(self) -> List[Dict[str, Any]]:
        """Return the tool registry.

        Returns:
            One dict per chart type with ``"name"``, ``"description"`` and
            ``"function"`` (the bound method that renders it).
        """
        return [
            {
                "name": "create_line_chart",
                "description": "Create a line chart from CSV data",
                "function": self.create_line_chart,
            },
            {
                "name": "create_bar_chart",
                "description": "Create a bar chart from CSV data",
                "function": self.create_bar_chart,
            },
            {
                "name": "create_scatter_plot",
                "description": "Create a scatter plot from CSV data",
                "function": self.create_scatter_plot,
            },
            {
                "name": "create_histogram",
                "description": "Create a histogram from CSV data",
                "function": self.create_histogram,
            },
            {
                "name": "create_pie_chart",
                "description": "Create a pie chart from CSV data",
                "function": self.create_pie_chart,
            },
        ]

    def _figure_to_base64(self, fig) -> str:
        """Render ``fig`` as PNG and return it base64-encoded.

        The figure is always closed, even if saving fails, so it cannot
        accumulate in pyplot's global figure registry.
        """
        buf = io.BytesIO()
        try:
            fig.savefig(buf, format='png', dpi=self.dpi)
        finally:
            plt.close(fig)
        return base64.b64encode(buf.getvalue()).decode('utf-8')

    def create_line_chart(self, filename: str, x_column: str, y_column: str,
                          title: Optional[str] = None, limit: int = 50) -> Dict[str, Any]:
        """Create a line chart of ``y_column`` against ``x_column``.

        Args:
            filename: CSV file to read (see ``_load_dataframe``).
            x_column: Column plotted on the x axis.
            y_column: Column plotted on the y axis.
            title: Chart title; defaults to ``"<y> vs <x>"``.
            limit: Only the first ``limit`` rows are plotted.

        Returns:
            Dict with ``chart_type``, ``x_column``, ``y_column``,
            ``data_points`` (rows plotted) and ``image`` (base64 PNG).

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If a requested column is not in the file.
        """
        df = self._load_dataframe(filename)
        self._require_columns(df, filename, x_column, y_column)
        # head() already returns the whole frame when it has <= limit rows.
        df = df.head(limit)

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            ax.plot(df[x_column], df[y_column], marker='o', linestyle='-')
            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(title or f"{y_column} vs {x_column}")
            ax.grid(True)
            img_str = self._figure_to_base64(fig)
        finally:
            # No-op after a successful render; on failure it keeps the
            # half-built figure from lingering in pyplot's registry.
            plt.close(fig)

        return {
            "chart_type": "line",
            "x_column": x_column,
            "y_column": y_column,
            "data_points": len(df),
            "image": img_str,
        }

    def create_bar_chart(self, filename: str, x_column: str, y_column: str,
                         title: Optional[str] = None, limit: int = 20) -> Dict[str, Any]:
        """Create a bar chart of ``y_column`` for each ``x_column`` value.

        Args:
            filename: CSV file to read (see ``_load_dataframe``).
            x_column: Column providing the bar positions/labels.
            y_column: Column providing the bar heights.
            title: Chart title; defaults to ``"<y> by <x>"``.
            limit: Only the first ``limit`` rows are plotted.

        Returns:
            Dict with ``chart_type``, ``x_column``, ``y_column``,
            ``categories`` (rows plotted) and ``image`` (base64 PNG).

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If a requested column is not in the file.
        """
        df = self._load_dataframe(filename)
        self._require_columns(df, filename, x_column, y_column)
        df = df.head(limit)

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            ax.bar(df[x_column], df[y_column])
            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(title or f"{y_column} by {x_column}")

            # Slant the labels once there are enough bars for them to collide.
            # Operate on this Axes/Figure explicitly rather than on pyplot's
            # implicit "current" figure.
            if len(df) > 5:
                plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
            fig.tight_layout()

            img_str = self._figure_to_base64(fig)
        finally:
            plt.close(fig)

        return {
            "chart_type": "bar",
            "x_column": x_column,
            "y_column": y_column,
            "categories": len(df),
            "image": img_str,
        }

    def create_scatter_plot(self, filename: str, x_column: str, y_column: str,
                            color_column: Optional[str] = None,
                            title: Optional[str] = None) -> Dict[str, Any]:
        """Create a scatter plot of ``y_column`` against ``x_column``.

        Args:
            filename: CSV file to read (see ``_load_dataframe``).
            x_column: Column plotted on the x axis.
            y_column: Column plotted on the y axis.
            color_column: Optional column used to colour the points (with a
                colour bar). If it is not present in the file it is ignored
                and the points are drawn in a single colour.
            title: Chart title; defaults to ``"<y> vs <x>"``.

        Returns:
            Dict with ``chart_type``, ``x_column``, ``y_column``,
            ``color_column`` (as passed in), ``data_points`` and ``image``.

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If ``x_column`` or ``y_column`` is not in the file.
        """
        df = self._load_dataframe(filename)
        self._require_columns(df, filename, x_column, y_column)

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            if color_column and color_column in df.columns:
                scatter = ax.scatter(df[x_column], df[y_column],
                                     c=df[color_column], cmap='viridis', alpha=0.7)
                fig.colorbar(scatter, ax=ax, label=color_column)
            else:
                ax.scatter(df[x_column], df[y_column], alpha=0.7)

            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(title or f"{y_column} vs {x_column}")
            ax.grid(True, linestyle='--', alpha=0.7)

            img_str = self._figure_to_base64(fig)
        finally:
            plt.close(fig)

        return {
            "chart_type": "scatter",
            "x_column": x_column,
            "y_column": y_column,
            "color_column": color_column,
            "data_points": len(df),
            "image": img_str,
        }

    def create_histogram(self, filename: str, column: str, bins: int = 10,
                         title: Optional[str] = None) -> Dict[str, Any]:
        """Create a histogram of the values in ``column``.

        Args:
            filename: CSV file to read (see ``_load_dataframe``).
            column: Column whose distribution is plotted.
            bins: Number of bins passed to ``Axes.hist``.
            title: Chart title; defaults to ``"Distribution of <column>"``.

        Returns:
            Dict with ``chart_type``, ``column``, ``bins``, ``data_points``
            (rows in the file) and ``image`` (base64 PNG).

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If ``column`` is not in the file.
        """
        df = self._load_dataframe(filename)
        self._require_columns(df, filename, column)

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            ax.hist(df[column], bins=bins, alpha=0.7, edgecolor='black')
            ax.set_xlabel(column)
            ax.set_ylabel('Frequency')
            ax.set_title(title or f"Distribution of {column}")
            ax.grid(True, linestyle='--', alpha=0.7)

            img_str = self._figure_to_base64(fig)
        finally:
            plt.close(fig)

        return {
            "chart_type": "histogram",
            "column": column,
            "bins": bins,
            "data_points": len(df),
            "image": img_str,
        }

    def create_pie_chart(self, filename: str, label_column: str,
                         value_column: Optional[str] = None,
                         title: Optional[str] = None, limit: int = 10) -> Dict[str, Any]:
        """Create a pie chart of the categories in ``label_column``.

        Without ``value_column`` each slice is the number of rows per label
        (the ``limit`` most frequent labels). With ``value_column`` each
        slice is the sum of that column per label (the ``limit`` largest sums).

        Args:
            filename: CSV file to read (see ``_load_dataframe``).
            label_column: Column whose distinct values become the slices.
            value_column: Optional column summed per label for slice sizes.
            title: Chart title; defaults to ``"Distribution of <label>"``.
            limit: Maximum number of slices.

        Returns:
            Dict with ``chart_type``, ``label_column``, ``value_column``,
            ``categories`` (number of slices) and ``image`` (base64 PNG).

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If a requested column is not in the file.
        """
        df = self._load_dataframe(filename)

        if value_column is None:
            self._require_columns(df, filename, label_column)
            counts = df[label_column].value_counts().head(limit)
            labels = counts.index.tolist()
            values = counts.values.tolist()
        else:
            self._require_columns(df, filename, label_column, value_column)
            grouped = df.groupby(label_column)[value_column].sum().reset_index()
            grouped = grouped.nlargest(limit, value_column)
            labels = grouped[label_column].tolist()
            values = grouped[value_column].tolist()

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            ax.pie(values, labels=labels, autopct='%1.1f%%', startangle=90, shadow=True)
            # Equal aspect ratio so the pie is drawn as a circle.
            ax.axis('equal')
            ax.set_title(title or f"Distribution of {label_column}")

            img_str = self._figure_to_base64(fig)
        finally:
            plt.close(fig)

        return {
            "chart_type": "pie",
            "label_column": label_column,
            "value_column": value_column,
            "categories": len(labels),
            "image": img_str,
        }
|
|