import streamlit as st
import pandas as pd
import os
import base64
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from datasets import load_dataset


def load_css():
    """Load custom CSS"""
    with open('styles/custom.css') as f:
        st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)


def create_logo():
    """Create and display the logo, falling back to a text header."""
    from PIL import Image

    logo_path = "assets/python_huggingface_logo.png"

    if os.path.exists(logo_path):
        image = Image.open(logo_path)
        st.image(image, width=200)
    else:
        # Fall back to a styled text header when the logo image is missing
        st.markdown(
            """
            <div style="display: flex; justify-content: center; margin-bottom: 20px;">
                <h2 style="color: #2196F3;">Python & HuggingFace Explorer</h2>
            </div>
            """,
            unsafe_allow_html=True
        )
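
# Illustrative usage (a sketch, not executed here): a Streamlit page, e.g. a
# hypothetical app.py, would typically call these helpers near the top.
#
#     load_css()
#     create_logo()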


def get_dataset_info(dataset_name):
    """Get basic information about a HuggingFace dataset"""
    if not dataset_name or not isinstance(dataset_name, str):
        st.error("Invalid dataset name")
        return None, None

    try:
        st.info(f"Loading dataset: {dataset_name}...")

        try:
            # Load all available splits, then work with the first one
            dataset = load_dataset(dataset_name, streaming=False)
            first_split = next(iter(dataset.keys()))
            data = dataset[first_split]
        except Exception as e:
            st.warning(f"Couldn't load dataset with default configuration: {str(e)}. Trying specific splits...")
            # Fall back to loading a single standard split
            for split_name in ["train", "test", "validation"]:
                try:
                    st.info(f"Trying to load '{split_name}' split...")
                    data = load_dataset(dataset_name, split=split_name, streaming=False)
                    break
                except Exception as split_error:
                    if split_name == "validation":
                        # Last split attempted; give up
                        st.error(f"Failed to load dataset with any standard split: {str(split_error)}")
                        return None, None
                    continue

        info = {
            "Dataset": dataset_name,
            "Number of examples": len(data),
            "Features": list(data.features.keys()),
            "Sample": data[0] if len(data) > 0 else None
        }

        st.success(f"Successfully loaded dataset with {info['Number of examples']} examples")
        return info, data
    except Exception as e:
        st.error(f"Error loading dataset: {str(e)}")
        if "Connection error" in str(e) or "timeout" in str(e).lower():
            st.warning("Network issue detected. Please check your internet connection and try again.")
        elif "not found" in str(e).lower():
            st.warning(f"Dataset '{dataset_name}' not found. Please check the dataset name and try again.")
        return None, None
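
# Illustrative usage (a sketch, not executed here): how a Streamlit page might
# consume the (info, data) pair returned above. "imdb" is just an example name.
#
#     info, data = get_dataset_info("imdb")
#     if info is not None:
#         st.write(f"Loaded {info['Number of examples']} examples")
#         st.json({k: v for k, v in info.items() if k != "Sample"})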


def run_code(code):
    """Run Python code and capture output, errors, and matplotlib figures"""
    import io
    import time
    from contextlib import redirect_stdout, redirect_stderr

    stdout_capture = io.StringIO()
    stderr_capture = io.StringIO()

    results = {
        "output": "",
        "error": "",
        "figures": []
    }

    # Reject excessively large submissions
    if len(code) > 100000:
        results["error"] = "Code submission too large. Please reduce the size."
        return results

    # Very basic string-based screening. This is not a real sandbox: it only
    # blocks a few obvious patterns and can be bypassed.
    dangerous_patterns = ['os.system', 'subprocess', 'eval(', 'shutil.rmtree', 'open(', 'with open']
    for pattern in dangerous_patterns:
        if pattern in code:
            results["error"] = f"Potential security risk: {pattern} is not allowed."
            return results

    # Remember which figures already exist so only new ones are returned
    initial_figs = plt.get_fignums()

    MAX_EXECUTION_TIME = 30  # seconds
    start_time = time.time()

    try:
        # Globals passed to exec(); full builtins are still exposed, so this
        # provides convenience imports rather than isolation.
        safe_globals = {
            'plt': plt,
            'pd': pd,
            'np': np,
            'sns': sns,
            'print': print,
            '__builtins__': __builtins__,
        }

        # Make a few optional libraries available if they are installed
        for module_name in ['datasets', 'transformers', 'sklearn', 'math']:
            try:
                module = __import__(module_name)
                safe_globals[module_name] = module
            except ImportError:
                pass

        with redirect_stdout(stdout_capture), redirect_stderr(stderr_capture):
            exec(code, safe_globals)

            # Note: this check runs only after exec() returns, so it reports
            # long-running code but cannot interrupt it.
            if time.time() - start_time > MAX_EXECUTION_TIME:
                raise TimeoutError("Code execution exceeded maximum allowed time.")

        results["output"] = stdout_capture.getvalue()

        # Append anything written to stderr (warnings, tracebacks)
        stderr_output = stderr_capture.getvalue()
        if stderr_output:
            if results["output"]:
                results["output"] += "\n\n--- Warnings/Errors ---\n" + stderr_output
            else:
                results["output"] = "--- Warnings/Errors ---\n" + stderr_output

        # Collect any figures created by the executed code
        final_figs = plt.get_fignums()
        new_figs = set(final_figs) - set(initial_figs)

        for fig_num in new_figs:
            fig = plt.figure(fig_num)
            results["figures"].append(fig)

    except Exception as e:
        results["error"] = f"{type(e).__name__}: {str(e)}"

    return results
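
# Illustrative usage (a sketch, not executed here): how a Streamlit page might
# render the result dict produced by run_code().
#
#     results = run_code("import numpy as np\nprint(np.arange(3))")
#     if results["error"]:
#         st.error(results["error"])
#     if results["output"]:
#         st.code(results["output"])
#     for fig in results["figures"]:
#         st.pyplot(fig)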


def get_dataset_preview(data, max_rows=10):
    """Convert a HuggingFace dataset to a pandas DataFrame for preview"""
    try:
        # Slicing a HuggingFace Dataset returns a dict of column -> list,
        # which pandas accepts directly.
        df = pd.DataFrame(data[:max_rows])
        return df
    except Exception as e:
        st.error(f"Error converting dataset to DataFrame: {str(e)}")
        return None
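
# Illustrative usage (a sketch, not executed here):
#
#     preview_df = get_dataset_preview(data, max_rows=20)
#     if preview_df is not None:
#         st.dataframe(preview_df)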


def generate_basic_stats(data):
    """Generate basic per-column statistics for a dataset"""
    try:
        df = pd.DataFrame(data)

        stats = {}

        for col in df.columns:
            col_stats = {}

            # Numeric columns: summary statistics
            if pd.api.types.is_numeric_dtype(df[col]):
                col_stats["mean"] = df[col].mean()
                col_stats["median"] = df[col].median()
                col_stats["std"] = df[col].std()
                col_stats["min"] = df[col].min()
                col_stats["max"] = df[col].max()
                col_stats["missing"] = df[col].isna().sum()

            # String/object columns: cardinality and most common values
            elif pd.api.types.is_string_dtype(df[col]) or pd.api.types.is_object_dtype(df[col]):
                col_stats["unique_values"] = df[col].nunique()
                col_stats["most_common"] = df[col].value_counts().head(5).to_dict() if df[col].nunique() < 100 else "Too many unique values"
                col_stats["missing"] = df[col].isna().sum()

            stats[col] = col_stats

        return stats
    except Exception as e:
        st.error(f"Error generating statistics: {str(e)}")
        return None
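
# Illustrative usage (a sketch, not executed here):
#
#     stats = generate_basic_stats(data)
#     if stats is not None:
#         for col, col_stats in stats.items():
#             st.subheader(col)
#             st.json(col_stats)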


def create_visualization(data, viz_type, x_col=None, y_col=None, hue_col=None):
    """Create a visualization based on the selected type and columns"""
    try:
        df = pd.DataFrame(data)

        fig, ax = plt.subplots(figsize=(10, 6))

        if viz_type == "Bar Chart":
            if x_col and y_col:
                sns.barplot(x=x_col, y=y_col, hue=hue_col, data=df, ax=ax)
            else:
                st.warning("Bar charts require both X and Y columns.")
                return None

        elif viz_type == "Line Chart":
            if x_col and y_col:
                sns.lineplot(x=x_col, y=y_col, hue=hue_col, data=df, ax=ax)
            else:
                st.warning("Line charts require both X and Y columns.")
                return None

        elif viz_type == "Scatter Plot":
            if x_col and y_col:
                sns.scatterplot(x=x_col, y=y_col, hue=hue_col, data=df, ax=ax)
            else:
                st.warning("Scatter plots require both X and Y columns.")
                return None

        elif viz_type == "Histogram":
            if x_col:
                sns.histplot(df[x_col], ax=ax)
            else:
                st.warning("Histograms require an X column.")
                return None

        elif viz_type == "Box Plot":
            if x_col and y_col:
                sns.boxplot(x=x_col, y=y_col, hue=hue_col, data=df, ax=ax)
            else:
                st.warning("Box plots require both X and Y columns.")
                return None

        elif viz_type == "Count Plot":
            if x_col:
                sns.countplot(x=x_col, hue=hue_col, data=df, ax=ax)
            else:
                st.warning("Count plots require an X column.")
                return None

        else:
            st.warning(f"Unsupported visualization type: {viz_type}")
            return None

        # Title and axis labels
        if x_col and y_col:
            ax.set_title(f"{viz_type} of {y_col} vs {x_col}")
        else:
            ax.set_title(f"{viz_type} of {x_col or y_col or ''}")
        ax.set_xlabel(x_col if x_col else "")
        ax.set_ylabel(y_col if y_col else "")
        fig.tight_layout()

        return fig

    except Exception as e:
        st.error(f"Error creating visualization: {str(e)}")
        return None
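
# Illustrative usage (a sketch, not executed here): the column name is an
# example and depends on the loaded dataset.
#
#     fig = create_visualization(data, "Histogram", x_col="label")
#     if fig is not None:
#         st.pyplot(fig)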


def get_popular_datasets(category=None, limit=10):
    """Get popular HuggingFace datasets, optionally filtered by category"""
    popular_datasets = {
        "Text": ["glue", "imdb", "squad", "wikitext", "ag_news"],
        "Image": ["cifar10", "cifar100", "mnist", "fashion_mnist", "coco"],
        "Audio": ["common_voice", "librispeech_asr", "voxpopuli", "voxceleb", "audiofolder"],
        "Multimodal": ["conceptual_captions", "flickr8k", "hateful_memes", "nlvr", "vqa"]
    }

    if category and category in popular_datasets:
        return popular_datasets[category][:limit]
    else:
        all_datasets = []
        for cat_datasets in popular_datasets.values():
            all_datasets.extend(cat_datasets)
        return all_datasets[:limit]