""" AIFinder Dataset Evaluator Supports various HuggingFace dataset formats for evaluation. """ import os import re import json import random from collections import defaultdict from typing import Any from datasets import load_dataset from tqdm import tqdm HF_TOKEN = os.environ.get("HF_TOKEN") SUPPORTED_FORMATS = { "teichai_healer": { "name": "TeichAI Healer Format", "description": "TeichAI Healer-Alpha format with 'prompt' and 'response' fields", "examples": ["TeichAI/Healer-Alpha-16k"], "check": lambda row: ( "prompt" in row and "response" in row and isinstance(row.get("prompt"), (str, dict)) and isinstance(row.get("response"), (str, dict)) ), }, "teichai": { "name": "TeichAI Format", "description": "TeichAI dataset format with 'conversations' or 'messages' containing role/content", "examples": [ "TeichAI/claude-4.5-opus-high-reasoning-250x", "TeichAI/Claude-3.5-Sonnet-128k", ], "check": lambda row: _check_conversations_format(row), }, "combined": { "name": "Combined Outputs", "description": "Dataset with 'output', 'outputs', 'generated' or 'completion' field", "examples": ["jacobmorrison/gpt-oss-20b-combined-outputs"], "check": lambda row: ( "prompt" not in row and "response" not in row and not _check_conversations_format(row) and ( any(k in row for k in ["output", "outputs", "generated", "completion"]) or ( isinstance(row.get("data"), str) or isinstance(row.get("example"), str) ) ) ), }, "conversations": { "name": "Conversations Format", "description": "Dataset with 'conversations' or 'messages' field containing role/content pairs", "examples": [ "TeichAI/claude-4.5-opus-high-reasoning-250x", "ianncity/Hunter-Alpha-SFT-300000x", ], "check": lambda row: _check_conversations_format(row), }, "chat": { "name": "Chat Format", "description": "Dataset with 'chat' or 'dialogue' field", "examples": ["some/chat-dataset"], "check": lambda row: ("chat" in row.keys() or "dialogue" in row.keys()), }, "text": { "name": "Text Field", "description": "Dataset with a 'text' field containing the response", "examples": ["some/text-dataset"], "check": lambda row: "text" in row and isinstance(row.get("text"), str), }, "response": { "name": "Response Field", "description": "Dataset with 'response' or 'output' field", "examples": ["some/response-dataset"], "check": lambda row: "response" in row or "output" in row, }, "content": { "name": "Content Field", "description": "Dataset with 'content' field (single message)", "examples": ["some/content-dataset"], "check": lambda row: "content" in row and isinstance(row.get("content"), str), }, "messages": { "name": "Messages Array", "description": "Dataset where each row is an array of message objects", "examples": ["some/messages-dataset"], "check": lambda row: isinstance(row, list) and len(row) > 0 and isinstance(row[0], dict), }, "sft": { "name": "SFT Format", "description": "Supervised Fine-Tuning format with 'prompt' and 'response' or 'completion'", "examples": ["some/sft-dataset"], "check": lambda row: "prompt" in row and ("response" in row or "completion" in row), }, "qa": { "name": "Q&A Format", "description": "Question-Answer format with 'question' and 'answer' fields", "examples": ["some/qa-dataset"], "check": lambda row: "question" in row and "answer" in row, }, "combined": { "name": "Combined Outputs", "description": "Dataset with 'input', 'output', 'outputs' or combined text field", "examples": ["jacobmorrison/gpt-oss-20b-combined-outputs"], "check": lambda row: any( k in row for k in ["output", "outputs", "combined", "generated", "completion"] ) or (isinstance(row.get("data"), str) or isinstance(row.get("example"), str)), }, "completion": { "name": "Completion Format", "description": "Dataset with 'completion' field (like OpenAI fine-tuning)", "examples": ["some/completion-dataset"], "check": lambda row: "completion" in row and isinstance(row.get("completion"), str), }, "generations": { "name": "Generations Format", "description": "Dataset with 'generations' or 'generation' field (LLM outputs)", "examples": ["some/generations-dataset"], "check": lambda row: "generations" in row or "generation" in row, }, } def _check_conversations_format(row): """Check if row has conversations/messages with proper role/content structure.""" conv_key = ( "conversations" if "conversations" in row else "messages" if "messages" in row else None ) if not conv_key: return False convos = row.get(conv_key) if not isinstance(convos, list) or not convos: return False first_msg = convos[0] if isinstance(first_msg, dict): return "role" in first_msg and "content" in first_msg return False def detect_format(rows, sample_size=10): """Detect the dataset format from sample rows.""" if not rows: return None, [] sample = rows[:sample_size] for fmt_name, fmt_info in SUPPORTED_FORMATS.items(): check_func = fmt_info["check"] matches = 0 for row in sample: try: if check_func(row): matches += 1 except: pass if matches >= len(sample) * 0.6: return fmt_name, SUPPORTED_FORMATS[fmt_name] return None, [] def _parse_msg(msg): """Parse a message that may be a dict or a JSON string.""" if isinstance(msg, dict): return msg if isinstance(msg, str): try: parsed = json.loads(msg) if isinstance(parsed, dict): return parsed except (ValueError, Exception): pass return {} def _extract_response_only(content): """Extract only the final response, stripping CoT blocks.""" if not content: return "" think_match = re.search(r"(.*)$", content, re.DOTALL) if think_match: response = think_match.group(1).strip() if response: return response return content def extract_texts_conversations(rows): """Extract from conversations/messages format.""" texts = [] for row in rows: convos = row.get("conversations") or row.get("messages") or [] if not convos: continue for msg in convos: msg = _parse_msg(msg) role = msg.get("role", "") content = msg.get("content", "") if role in ("assistant", "gpt", "model", "ai") and content: response_only = _extract_response_only(content) if response_only and len(response_only) > 50: texts.append(response_only) return texts def extract_texts_chat(rows): """Extract from chat/dialogue format.""" texts = [] for row in rows: chat = row.get("chat") or row.get("dialogue") or [] if isinstance(chat, list): for msg in chat: msg = _parse_msg(msg) role = msg.get("role", "") content = msg.get("content", "") if role in ("assistant", "ai") and content: response_only = _extract_response_only(content) if response_only and len(response_only) > 50: texts.append(response_only) return texts def extract_texts_text_field(rows, field="text"): """Extract from a text field.""" texts = [] for row in rows: content = row.get(field, "") if content and len(str(content)) > 50: response_only = _extract_response_only(str(content)) if response_only and len(response_only) > 50: texts.append(response_only) return texts def extract_texts_sft(rows): """Extract from SFT format (prompt + response/completion).""" texts = [] for row in rows: response = row.get("response") or row.get("completion") or "" if response and len(str(response)) > 50: response_only = _extract_response_only(str(response)) if response_only and len(response_only) > 50: texts.append(response_only) return texts def extract_texts_qa(rows): """Extract from Q&A format (use answer as response).""" texts = [] for row in rows: answer = row.get("answer", "") if answer and len(str(answer)) > 50: response_only = _extract_response_only(str(answer)) if response_only and len(response_only) > 50: texts.append(response_only) return texts def extract_texts_messages_array(rows): """Extract from messages array format.""" texts = [] for row in rows: if isinstance(row, list): for msg in row: msg = _parse_msg(msg) role = msg.get("role", "") content = msg.get("content", "") if role in ("assistant", "ai", "model") and content: response_only = _extract_response_only(content) if response_only and len(response_only) > 50: texts.append(response_only) return texts def extract_texts_teichai_healer(rows): """Extract from TeichAI Healer-Alpha format (prompt + response fields).""" texts = [] for row in rows: response = row.get("response") if response: if isinstance(response, dict): response = response.get("content") or response.get("text") or "" if response and len(str(response)) > 50: response_only = _extract_response_only(str(response)) if response_only and len(response_only) > 50: texts.append(response_only) return texts def _get_dataset_size(dataset_id, load_kwargs): """Get dataset size without loading all data.""" try: ds = load_dataset(dataset_id, split="train", streaming=True, **load_kwargs) return ds.info.num_rows except Exception: pass try: import pandas as pd url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet" df = pd.read_parquet(url) return len(df) except Exception: return 0 def _streaming_download_with_progress( dataset_id, load_kwargs, progress_callback=None, max_rows=None ): """Download dataset using streaming with progress tracking.""" import pandas as pd total_rows = _get_dataset_size(dataset_id, load_kwargs) print(f"[PROGRESS] Dataset size: {total_rows} rows", flush=True) download_limit = max_rows if max_rows and max_rows < total_rows else total_rows if progress_callback: progress_callback(0, download_limit, "fetching_info") print(f"[PROGRESS] Initial callback: 0/{download_limit}", flush=True) try: ds = load_dataset(dataset_id, split="train", streaming=True, **load_kwargs) rows = [] for i, row in enumerate(tqdm(ds, desc="Downloading", unit="rows")): rows.append(row) if progress_callback: progress_callback(i + 1, download_limit, "downloading") if i % 100 == 0: print(f"[PROGRESS] Downloaded {i + 1}/{download_limit}", flush=True) if max_rows and i + 1 >= max_rows: print(f"[PROGRESS] Stopping at {i + 1} rows", flush=True) break return rows, min(len(rows), total_rows or len(rows)) except Exception as e: print(f"[PROGRESS] Streaming failed: {e}", flush=True) pass try: url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet" df = pd.read_parquet(url) if max_rows and max_rows < len(df): df = df.head(max_rows) print(f"[PROGRESS] Limited to first {max_rows} rows", flush=True) total = len(df) if progress_callback: progress_callback(0, total, "downloading") rows = [] for i, row in enumerate(df.to_dict(orient="records")): rows.append(row) if progress_callback: progress_callback(i + 1, total, "downloading") return rows, total except Exception as e: raise e def _load_sample_rows(dataset_id, sample_size, load_kwargs): """Load just a few rows for format detection.""" try: ds = load_dataset(dataset_id, split="train", streaming=True, **load_kwargs) return [next(iter(ds)) for _ in range(sample_size)] except Exception: pass try: import pandas as pd url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet" df = pd.read_parquet(url) return df.head(sample_size).to_dict(orient="records") except Exception: return [] def load_dataset_texts( dataset_id, max_samples=None, sample_size=None, progress_callback=None, custom_format=None, ): """ Load a HuggingFace dataset and extract assistant response texts. Returns: { "texts": list of extracted texts, "format": detected format name, "format_info": format info dict, "total_rows": total rows in dataset, "supported": bool, "error": error message if failed, } progress_callback: optional function(current, total, stage) -> None stage can be: "fetching_info", "downloading", "extracting" custom_format: optional custom format specification string Examples: - "column: response" - "column: prompt, column: response" - "pattern: user:, pattern: assistant:" - "user:[startuser]assistant:[startassistant]" """ load_kwargs = {"token": HF_TOKEN} if HF_TOKEN else {} rows = [] total_rows = 0 if sample_size: total_rows = _get_dataset_size(dataset_id, load_kwargs) if total_rows == 0: return { "texts": [], "format": None, "format_info": None, "total_rows": 0, "supported": False, "error": "Dataset is empty", } rows = _load_sample_rows(dataset_id, sample_size, load_kwargs) else: if progress_callback: try: rows, total_rows = _streaming_download_with_progress( dataset_id, load_kwargs, progress_callback, max_samples ) except Exception as e: fallback_error = None try: import pandas as pd url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet" df = pd.read_parquet(url) if max_samples and max_samples < len(df): df = df.head(max_samples) total_rows = len(df) if progress_callback: progress_callback(0, total_rows, "downloading") rows = [] for i, row in enumerate(df.to_dict(orient="records")): rows.append(row) if progress_callback: progress_callback(i + 1, total_rows, "downloading") except Exception as e2: fallback_error = str(e2) return { "texts": [], "format": None, "format_info": None, "total_rows": 0, "supported": False, "error": f"Failed to load: {e}. Parquet fallback also failed: {fallback_error}", } else: try: ds = load_dataset(dataset_id, split="train", **load_kwargs) total_rows = len(ds) if max_samples and max_samples < total_rows: total_rows = max_samples rows = list(ds)[:max_samples] if max_samples else list(ds) except Exception as e: fallback_error = None try: import pandas as pd url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet" df = pd.read_parquet(url) if max_samples and max_samples < len(df): df = df.head(max_samples) total_rows = len(df) rows = df.to_dict(orient="records") except Exception as e2: fallback_error = str(e2) return { "texts": [], "format": None, "format_info": None, "total_rows": 0, "supported": False, "error": f"Failed to load: {e}. Parquet fallback also failed: {fallback_error}", } if not rows: return { "texts": [], "format": None, "format_info": None, "total_rows": 0, "supported": False, "error": "Dataset is empty", } detect_rows = rows[:sample_size] if sample_size else rows custom_format_spec = custom_format if custom_format_spec and check_custom_format(detect_rows, custom_format_spec): fmt_name = "custom" fmt_info = { "name": "Custom Format", "description": f"Custom format: {custom_format_spec}", "examples": [], } else: fmt_name, fmt_info = detect_format(detect_rows, sample_size=sample_size or 10) if fmt_name is None: return { "texts": [], "format": None, "format_info": None, "total_rows": total_rows, "supported": False, "error": "Unknown dataset format. Supported formats: " + ", ".join(f["name"] for f in SUPPORTED_FORMATS.values()), } extractors = { "teichai_healer": extract_texts_teichai_healer, "teichai": extract_texts_conversations, "conversations": extract_texts_conversations, "chat": extract_texts_chat, "text": lambda r: extract_texts_text_field(r, "text"), "response": lambda r: extract_texts_text_field(r, "response") or extract_texts_text_field(r, "output"), "content": lambda r: extract_texts_text_field(r, "content"), "messages": extract_texts_messages_array, "sft": extract_texts_sft, "qa": extract_texts_qa, "combined": lambda r: ( extract_texts_text_field(r, "output") or extract_texts_text_field(r, "outputs") or extract_texts_text_field(r, "generated") or extract_texts_text_field(r, "completion") or extract_texts_text_field(r, "combined") or extract_texts_text_field(r, "data") or extract_texts_text_field(r, "example") ), "completion": lambda r: extract_texts_text_field(r, "completion"), "generations": lambda r: ( extract_texts_text_field(r, "generations") or extract_texts_text_field(r, "generation") ), "custom": lambda r: extract_texts_custom(r, custom_format_spec), } extractor = extractors.get(fmt_name) texts = extractor(rows) if extractor else [] if max_samples and len(texts) > max_samples: random.seed(42) texts = random.sample(texts, max_samples) return { "texts": texts, "format": fmt_name, "format_info": fmt_info, "total_rows": total_rows, "supported": True, "error": None, } def parse_custom_format_spec(spec): """ Parse custom format specification. Supported formats: - "column: " - extract single field as text - "column: , column: " - extract from two columns (user/assistant) - "pattern: user, pattern: assistant" - use regex patterns - "delimiter: " - use delimiter to split columns Examples: - "column: response" - "column: prompt, column: response" - "pattern: user:, pattern: assistant:" - "user:[startuser]assistant:[startassistant]" """ if not spec: return None spec = spec.strip() result = { "type": None, "user_field": None, "assistant_field": None, "user_pattern": None, "assistant_pattern": None, } if spec.startswith("column:") or spec.startswith("col:"): cols_spec = spec.replace("column:", "").replace("col:", "").strip() if "," in cols_spec: parts = [p.strip() for p in cols_spec.split(",")] if len(parts) >= 2: result["type"] = "two_column" result["user_field"] = parts[0] result["assistant_field"] = parts[1] else: result["type"] = "single_column" result["assistant_field"] = cols_spec return result if spec.startswith("pattern:") or spec.startswith("regex:"): patterns_spec = spec.replace("pattern:", "").replace("regex:", "").strip() if "," in patterns_spec: parts = [p.strip() for p in patterns_spec.split(",")] if len(parts) >= 2: result["type"] = "two_pattern" result["user_pattern"] = parts[0] result["assistant_pattern"] = parts[1] else: result["type"] = "single_pattern" result["assistant_pattern"] = patterns_spec return result if "user:" in spec.lower() and "assistant:" in spec.lower(): import re user_match = re.search( r"user:\s*(\[.*?\]|(?:(?!\s+assistant:).)+)", spec, re.IGNORECASE | re.DOTALL, ) assistant_match = re.search( r"assistant:\s*(\[.*?\]|(?:(?:\s+user:|$).)+)", spec, re.IGNORECASE | re.DOTALL, ) if user_match and assistant_match: result["type"] = "two_pattern" result["user_pattern"] = user_match.group(1).strip() result["assistant_pattern"] = assistant_match.group(1).strip() return result if "[startuser]" in spec and "[startassistant]" in spec: result["type"] = "two_pattern" result["user_pattern"] = re.escape("[startuser]") result["assistant_pattern"] = re.escape("[startassistant]") return result if "," in spec: parts = [p.strip() for p in spec.split(",")] if len(parts) >= 2: result["type"] = "two_column" result["user_field"] = parts[0] result["assistant_field"] = parts[1] return result result["type"] = "single_column" result["assistant_field"] = spec return result def extract_texts_custom(rows, format_spec): """Extract texts using custom format specification.""" parsed = parse_custom_format_spec(format_spec) if not parsed or not parsed.get("type"): return [] texts = [] if parsed["type"] == "single_column": field = parsed["assistant_field"] for row in rows: content = row.get(field, "") if content and len(str(content)) > 50: response_only = _extract_response_only(str(content)) if response_only and len(response_only) > 50: texts.append(response_only) elif parsed["type"] == "two_column": user_field = parsed.get("user_field") assistant_field = parsed["assistant_field"] for row in rows: user_content = row.get(user_field, "") if user_field else "" assistant_content = row.get(assistant_field, "") if assistant_content and len(str(assistant_content)) > 50: response_only = _extract_response_only(str(assistant_content)) if response_only and len(response_only) > 50: texts.append(response_only) elif parsed["type"] == "single_pattern": pattern = parsed.get("assistant_pattern") if pattern: try: regex = re.compile(pattern, re.DOTALL | re.IGNORECASE) for row in rows: row_str = str(row) match = regex.search(row_str) if match: content = match.group(1) if match.groups() else match.group(0) if content and len(content) > 50: response_only = _extract_response_only(content) if response_only and len(response_only) > 50: texts.append(response_only) except re.error: pass elif parsed["type"] == "two_pattern": user_pattern = parsed.get("user_pattern") assistant_pattern = parsed.get("assistant_pattern") if assistant_pattern: try: user_regex = ( re.compile(user_pattern, re.DOTALL | re.IGNORECASE) if user_pattern else None ) assistant_regex = re.compile( assistant_pattern, re.DOTALL | re.IGNORECASE ) for row in rows: row_str = str(row) match = assistant_regex.search(row_str) if match: content = match.group(1) if match.groups() else match.group(0) if content and len(content) > 50: response_only = _extract_response_only(content) if response_only and len(response_only) > 50: texts.append(response_only) except re.error: pass return texts def check_custom_format(rows, format_spec): """Check if custom format applies to the dataset.""" parsed = parse_custom_format_spec(format_spec) if not parsed or not parsed.get("type"): return False if not rows: return False sample = rows[0] if parsed["type"] == "single_column": return parsed.get("assistant_field") in sample if parsed["type"] == "two_column": return parsed.get("assistant_field") in sample if parsed["type"] in ("single_pattern", "two_pattern"): pattern = parsed.get("assistant_pattern") if pattern: try: regex = re.compile(pattern, re.DOTALL | re.IGNORECASE) return regex.search(str(sample)) is not None except re.error: pass return False def get_supported_formats(): """Return list of supported format info.""" return SUPPORTED_FORMATS