# AIFinder / dataset_evaluator.py
# (HuggingFace upload-page residue preserved as a comment:
#  "CompactAI's picture / Upload 18 files / bb0efe6 verified")
"""
AIFinder Dataset Evaluator
Supports various HuggingFace dataset formats for evaluation.
"""
import os
import re
import json
import random
from collections import defaultdict
from typing import Any
from datasets import load_dataset
from tqdm import tqdm
HF_TOKEN = os.environ.get("HF_TOKEN")
# Registry of recognizable dataset schemas. detect_format() walks this dict in
# insertion order and returns the first entry whose `check` predicate passes on
# at least 60% of sampled rows, so more specific formats must precede generic
# ones. Each entry holds a display name, description, example dataset ids, and
# a `check(row) -> bool` predicate.
#
# NOTE: this dict previously declared "combined" twice; Python keeps only the
# last value, which silently replaced the strict check with a permissive one
# that also claimed prompt/response rows. The two variants are merged below
# into a single strict check (key list from both), kept at the original
# position so detection priority is unchanged.
SUPPORTED_FORMATS = {
    "teichai_healer": {
        "name": "TeichAI Healer Format",
        "description": "TeichAI Healer-Alpha format with 'prompt' and 'response' fields",
        "examples": ["TeichAI/Healer-Alpha-16k"],
        "check": lambda row: (
            "prompt" in row
            and "response" in row
            and isinstance(row.get("prompt"), (str, dict))
            and isinstance(row.get("response"), (str, dict))
        ),
    },
    "teichai": {
        "name": "TeichAI Format",
        "description": "TeichAI dataset format with 'conversations' or 'messages' containing role/content",
        "examples": [
            "TeichAI/claude-4.5-opus-high-reasoning-250x",
            "TeichAI/Claude-3.5-Sonnet-128k",
        ],
        "check": lambda row: _check_conversations_format(row),
    },
    "combined": {
        "name": "Combined Outputs",
        "description": "Dataset with 'output', 'outputs', 'generated' or 'completion' field",
        "examples": ["jacobmorrison/gpt-oss-20b-combined-outputs"],
        # Strict: only claim rows that the prompt/response and
        # conversations-style formats do not already cover.
        "check": lambda row: (
            "prompt" not in row
            and "response" not in row
            and not _check_conversations_format(row)
            and (
                any(
                    k in row
                    for k in ["output", "outputs", "combined", "generated", "completion"]
                )
                or isinstance(row.get("data"), str)
                or isinstance(row.get("example"), str)
            )
        ),
    },
    "conversations": {
        "name": "Conversations Format",
        "description": "Dataset with 'conversations' or 'messages' field containing role/content pairs",
        "examples": [
            "TeichAI/claude-4.5-opus-high-reasoning-250x",
            "ianncity/Hunter-Alpha-SFT-300000x",
        ],
        "check": lambda row: _check_conversations_format(row),
    },
    "chat": {
        "name": "Chat Format",
        "description": "Dataset with 'chat' or 'dialogue' field",
        "examples": ["some/chat-dataset"],
        "check": lambda row: ("chat" in row.keys() or "dialogue" in row.keys()),
    },
    "text": {
        "name": "Text Field",
        "description": "Dataset with a 'text' field containing the response",
        "examples": ["some/text-dataset"],
        "check": lambda row: "text" in row and isinstance(row.get("text"), str),
    },
    "response": {
        "name": "Response Field",
        "description": "Dataset with 'response' or 'output' field",
        "examples": ["some/response-dataset"],
        "check": lambda row: "response" in row or "output" in row,
    },
    "content": {
        "name": "Content Field",
        "description": "Dataset with 'content' field (single message)",
        "examples": ["some/content-dataset"],
        "check": lambda row: "content" in row and isinstance(row.get("content"), str),
    },
    "messages": {
        "name": "Messages Array",
        "description": "Dataset where each row is an array of message objects",
        "examples": ["some/messages-dataset"],
        "check": lambda row: isinstance(row, list)
        and len(row) > 0
        and isinstance(row[0], dict),
    },
    "sft": {
        "name": "SFT Format",
        "description": "Supervised Fine-Tuning format with 'prompt' and 'response' or 'completion'",
        "examples": ["some/sft-dataset"],
        "check": lambda row: "prompt" in row
        and ("response" in row or "completion" in row),
    },
    "qa": {
        "name": "Q&A Format",
        "description": "Question-Answer format with 'question' and 'answer' fields",
        "examples": ["some/qa-dataset"],
        "check": lambda row: "question" in row and "answer" in row,
    },
    "completion": {
        "name": "Completion Format",
        "description": "Dataset with 'completion' field (like OpenAI fine-tuning)",
        "examples": ["some/completion-dataset"],
        "check": lambda row: "completion" in row
        and isinstance(row.get("completion"), str),
    },
    "generations": {
        "name": "Generations Format",
        "description": "Dataset with 'generations' or 'generation' field (LLM outputs)",
        "examples": ["some/generations-dataset"],
        "check": lambda row: "generations" in row or "generation" in row,
    },
}
def _check_conversations_format(row):
"""Check if row has conversations/messages with proper role/content structure."""
conv_key = (
"conversations"
if "conversations" in row
else "messages"
if "messages" in row
else None
)
if not conv_key:
return False
convos = row.get(conv_key)
if not isinstance(convos, list) or not convos:
return False
first_msg = convos[0]
if isinstance(first_msg, dict):
return "role" in first_msg and "content" in first_msg
return False
def detect_format(rows, sample_size=10):
    """Detect the dataset format from sample rows.

    Walks SUPPORTED_FORMATS in declaration order and returns the first
    format whose check passes on at least 60% of the sampled rows, so
    declaration order encodes detection priority.

    Returns:
        (format_name, format_info_dict) on success, (None, []) otherwise.
    """
    if not rows:
        return None, []
    sample = rows[:sample_size]
    for fmt_name, fmt_info in SUPPORTED_FORMATS.items():
        check_func = fmt_info["check"]
        matches = 0
        for row in sample:
            try:
                if check_func(row):
                    matches += 1
            # Check predicates may assume fields/types a row lacks; count any
            # failure as "no match". Narrowed from a bare `except:` so that
            # KeyboardInterrupt/SystemExit still propagate.
            except Exception:
                pass
        if matches >= len(sample) * 0.6:
            return fmt_name, fmt_info
    return None, []
def _parse_msg(msg):
"""Parse a message that may be a dict or a JSON string."""
if isinstance(msg, dict):
return msg
if isinstance(msg, str):
try:
parsed = json.loads(msg)
if isinstance(parsed, dict):
return parsed
except (ValueError, Exception):
pass
return {}
def _extract_response_only(content):
"""Extract only the final response, stripping CoT blocks."""
if not content:
return ""
think_match = re.search(r"</?think(?:ing)?>(.*)$", content, re.DOTALL)
if think_match:
response = think_match.group(1).strip()
if response:
return response
return content
def extract_texts_conversations(rows):
    """Collect assistant-side responses from conversations/messages rows.

    Only messages whose role is an assistant alias are kept; each kept
    message is CoT-stripped and must exceed 50 characters.
    """
    assistant_roles = ("assistant", "gpt", "model", "ai")
    texts = []
    for row in rows:
        messages = row.get("conversations") or row.get("messages") or []
        for raw in messages:
            parsed = _parse_msg(raw)
            if parsed.get("role", "") not in assistant_roles:
                continue
            content = parsed.get("content", "")
            if not content:
                continue
            cleaned = _extract_response_only(content)
            if cleaned and len(cleaned) > 50:
                texts.append(cleaned)
    return texts
def extract_texts_chat(rows):
    """Extract assistant responses from chat/dialogue format rows.

    Fix: accept the same assistant role aliases as
    extract_texts_conversations ('assistant', 'gpt', 'model', 'ai');
    previously this extractor silently dropped 'gpt'/'model' messages.
    """
    texts = []
    for row in rows:
        chat = row.get("chat") or row.get("dialogue") or []
        if isinstance(chat, list):
            for msg in chat:
                msg = _parse_msg(msg)
                role = msg.get("role", "")
                content = msg.get("content", "")
                if role in ("assistant", "gpt", "model", "ai") and content:
                    response_only = _extract_response_only(content)
                    if response_only and len(response_only) > 50:
                        texts.append(response_only)
    return texts
def extract_texts_text_field(rows, field="text"):
    """Pull responses from a single named column.

    Values are stringified, CoT-stripped, and kept only when the cleaned
    text exceeds 50 characters.
    """
    texts = []
    for row in rows:
        raw = str(row.get(field, "") or "")
        if len(raw) <= 50:
            continue
        cleaned = _extract_response_only(raw)
        if cleaned and len(cleaned) > 50:
            texts.append(cleaned)
    return texts
def extract_texts_sft(rows):
    """Pull responses from SFT rows; 'response' takes priority over
    'completion'. Kept texts are CoT-stripped and > 50 characters."""
    collected = []
    for row in rows:
        raw = str(row.get("response") or row.get("completion") or "")
        if len(raw) <= 50:
            continue
        cleaned = _extract_response_only(raw)
        if cleaned and len(cleaned) > 50:
            collected.append(cleaned)
    return collected
def extract_texts_qa(rows):
    """Pull answers from Q&A rows, treating the answer as the response.

    Answers are stringified, CoT-stripped, and kept only when the cleaned
    text exceeds 50 characters.
    """
    answers = []
    for row in rows:
        raw = str(row.get("answer", "") or "")
        if len(raw) <= 50:
            continue
        cleaned = _extract_response_only(raw)
        if cleaned and len(cleaned) > 50:
            answers.append(cleaned)
    return answers
def extract_texts_messages_array(rows):
    """Extract assistant responses when each row is itself a message list.

    Fix: accept the same assistant role aliases as
    extract_texts_conversations ('assistant', 'gpt', 'model', 'ai');
    'gpt'-role messages were previously dropped by this extractor only.
    """
    texts = []
    for row in rows:
        if isinstance(row, list):
            for msg in row:
                msg = _parse_msg(msg)
                role = msg.get("role", "")
                content = msg.get("content", "")
                if role in ("assistant", "gpt", "model", "ai") and content:
                    response_only = _extract_response_only(content)
                    if response_only and len(response_only) > 50:
                        texts.append(response_only)
    return texts
def extract_texts_teichai_healer(rows):
    """Pull responses from TeichAI Healer-Alpha rows.

    The 'response' field may be a plain string or a dict carrying the text
    under 'content' or 'text'. Kept texts are CoT-stripped and > 50 chars.
    """
    results = []
    for row in rows:
        value = row.get("response")
        if isinstance(value, dict):
            value = value.get("content") or value.get("text") or ""
        if not value:
            continue
        text = str(value)
        if len(text) <= 50:
            continue
        cleaned = _extract_response_only(text)
        if cleaned and len(cleaned) > 50:
            results.append(cleaned)
    return results
def _get_dataset_size(dataset_id, load_kwargs):
    """Best-effort row count without materializing the dataset.

    Tries the streaming dataset's metadata first, then falls back to reading
    the first parquet shard (NOTE: multi-shard datasets are undercounted by
    the fallback). Returns 0 when the size cannot be determined.
    """
    try:
        ds = load_dataset(dataset_id, split="train", streaming=True, **load_kwargs)
        num_rows = ds.info.num_rows
        # Streaming metadata often reports None; the original code returned
        # it as-is, which blew up callers comparing `total_rows > 0`.
        # Treat None/0 as unknown and fall through to the parquet probe.
        if num_rows:
            return num_rows
    except Exception:
        pass
    try:
        import pandas as pd

        url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
        return len(pd.read_parquet(url))
    except Exception:
        return 0
def _streaming_download_with_progress(dataset_id, load_kwargs, progress_callback=None):
    """Download all 'train' rows, reporting progress as they arrive.

    Tries HF streaming first; on any failure falls back to reading the first
    parquet shard directly (NOTE: multi-shard datasets would be truncated by
    the fallback). Fallback errors propagate to the caller, which reports
    both failure messages.

    Args:
        dataset_id: HuggingFace dataset identifier.
        load_kwargs: extra kwargs for load_dataset (e.g. auth token).
        progress_callback: optional fn(current, total, stage); stage is
            "fetching_info" or "downloading".

    Returns:
        (rows, total_rows); total_rows may be 0 when the size is unknown.
    """
    import pandas as pd

    total_rows = _get_dataset_size(dataset_id, load_kwargs)
    print(f"[PROGRESS] Dataset size: {total_rows} rows", flush=True)
    if total_rows > 0 and progress_callback:
        progress_callback(0, total_rows, "fetching_info")
        print(f"[PROGRESS] Initial callback: 0/{total_rows}", flush=True)
    try:
        ds = load_dataset(dataset_id, split="train", streaming=True, **load_kwargs)
        rows = []
        for i, row in enumerate(tqdm(ds, desc="Downloading", unit="rows")):
            rows.append(row)
            # Progress is only reported when the total is known, so the
            # percentage below never divides by zero.
            if progress_callback and total_rows > 0:
                progress_callback(i + 1, total_rows, "downloading")
                if i % 100 == 0:
                    print(
                        f"[PROGRESS] Downloaded {i + 1}/{total_rows} ({100 * (i + 1) / total_rows:.1f}%)",
                        flush=True,
                    )
        return rows, total_rows
    except Exception as e:
        print(f"[PROGRESS] Streaming failed: {e}", flush=True)
    # Parquet fallback: first shard only. (The original wrapped this in a
    # try/except that immediately re-raised — removed as a no-op.)
    url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
    df = pd.read_parquet(url)
    total = len(df)
    if progress_callback:
        progress_callback(0, total, "downloading")
    rows = []
    for i, row in enumerate(df.to_dict(orient="records")):
        rows.append(row)
        if progress_callback:
            progress_callback(i + 1, total, "downloading")
    return rows, total
def _load_sample_rows(dataset_id, sample_size, load_kwargs):
    """Load just a few rows for format detection.

    Returns up to sample_size rows (fewer when the dataset is smaller),
    or [] when both the streaming and parquet paths fail.
    """
    try:
        ds = load_dataset(dataset_id, split="train", streaming=True, **load_kwargs)
        # BUG FIX: the previous code did `next(iter(ds))` per sample, which
        # restarts the stream each time and returns the FIRST row repeated
        # sample_size times (and raised on datasets shorter than the sample).
        # zip() walks a single iterator and stops cleanly when exhausted.
        return [row for row, _ in zip(ds, range(sample_size))]
    except Exception:
        pass
    try:
        import pandas as pd

        url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
        df = pd.read_parquet(url)
        return df.head(sample_size).to_dict(orient="records")
    except Exception:
        return []
def load_dataset_texts(
    dataset_id,
    max_samples=None,
    sample_size=None,
    progress_callback=None,
    custom_format=None,
):
    """
    Load a HuggingFace dataset and extract assistant response texts.

    Args:
        dataset_id: HuggingFace dataset identifier (e.g. "org/name").
        max_samples: cap on extracted texts; when exceeded, a deterministic
            random sample (seed 42) of this size is returned.
        sample_size: when set, load only this many rows (fast preview path)
            instead of the full dataset; also bounds format detection.
        progress_callback: optional function(current, total, stage) -> None;
            stage can be: "fetching_info", "downloading", "extracting".
        custom_format: optional custom format specification string, tried
            before automatic detection. Examples:
            - "column: response"
            - "column: prompt, column: response"
            - "pattern: user:, pattern: assistant:"
            - "user:[startuser]assistant:[startassistant]"

    Returns: {
        "texts": list of extracted texts,
        "format": detected format name,
        "format_info": format info dict,
        "total_rows": total rows in dataset,
        "supported": bool,
        "error": error message if failed,
    }
    """
    # HF auth is optional; only pass a token when one is configured.
    load_kwargs = {"token": HF_TOKEN} if HF_TOKEN else {}
    rows = []
    total_rows = 0
    if sample_size:
        # Preview path: report the true dataset size but fetch only a sample.
        total_rows = _get_dataset_size(dataset_id, load_kwargs)
        if total_rows == 0:
            return {
                "texts": [],
                "format": None,
                "format_info": None,
                "total_rows": 0,
                "supported": False,
                "error": "Dataset is empty",
            }
        rows = _load_sample_rows(dataset_id, sample_size, load_kwargs)
    else:
        if progress_callback:
            # Full download with progress reporting via streaming.
            try:
                rows, total_rows = _streaming_download_with_progress(
                    dataset_id, load_kwargs, progress_callback
                )
            except Exception as e:
                # Last resort: read the first parquet shard directly.
                # NOTE(review): this fetches shard 0 only — presumably
                # truncates multi-shard datasets; confirm acceptable.
                fallback_error = None
                try:
                    import pandas as pd
                    url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
                    df = pd.read_parquet(url)
                    total_rows = len(df)
                    if progress_callback:
                        progress_callback(0, total_rows, "downloading")
                    rows = []
                    for i, row in enumerate(df.to_dict(orient="records")):
                        rows.append(row)
                        if progress_callback:
                            progress_callback(i + 1, total_rows, "downloading")
                except Exception as e2:
                    fallback_error = str(e2)
                    # Surface both failures so the caller can diagnose.
                    return {
                        "texts": [],
                        "format": None,
                        "format_info": None,
                        "total_rows": 0,
                        "supported": False,
                        "error": f"Failed to load: {e}. Parquet fallback also failed: {fallback_error}",
                    }
        else:
            # Full non-streaming download (no progress reporting requested).
            try:
                ds = load_dataset(dataset_id, split="train", **load_kwargs)
                total_rows = len(ds)
                rows = list(ds)
            except Exception as e:
                # Same first-shard parquet fallback as the streaming branch.
                fallback_error = None
                try:
                    import pandas as pd
                    url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
                    df = pd.read_parquet(url)
                    total_rows = len(df)
                    rows = df.to_dict(orient="records")
                except Exception as e2:
                    fallback_error = str(e2)
                    return {
                        "texts": [],
                        "format": None,
                        "format_info": None,
                        "total_rows": 0,
                        "supported": False,
                        "error": f"Failed to load: {e}. Parquet fallback also failed: {fallback_error}",
                    }
    if not rows:
        return {
            "texts": [],
            "format": None,
            "format_info": None,
            "total_rows": 0,
            "supported": False,
            "error": "Dataset is empty",
        }
    # Detect on a bounded sample when previewing; otherwise on all rows.
    detect_rows = rows[:sample_size] if sample_size else rows
    custom_format_spec = custom_format
    # A user-supplied custom format takes priority over auto-detection,
    # but only when it plausibly matches the loaded rows.
    if custom_format_spec and check_custom_format(detect_rows, custom_format_spec):
        fmt_name = "custom"
        fmt_info = {
            "name": "Custom Format",
            "description": f"Custom format: {custom_format_spec}",
            "examples": [],
        }
    else:
        fmt_name, fmt_info = detect_format(detect_rows, sample_size=sample_size or 10)
        if fmt_name is None:
            return {
                "texts": [],
                "format": None,
                "format_info": None,
                "total_rows": total_rows,
                "supported": False,
                "error": "Unknown dataset format. Supported formats: "
                + ", ".join(f["name"] for f in SUPPORTED_FORMATS.values()),
            }
    # Map each detected format to its extraction strategy. Keys must stay in
    # sync with SUPPORTED_FORMATS (plus the synthetic "custom" format).
    extractors = {
        "teichai_healer": extract_texts_teichai_healer,
        "teichai": extract_texts_conversations,
        "conversations": extract_texts_conversations,
        "chat": extract_texts_chat,
        "text": lambda r: extract_texts_text_field(r, "text"),
        "response": lambda r: extract_texts_text_field(r, "response")
        or extract_texts_text_field(r, "output"),
        "content": lambda r: extract_texts_text_field(r, "content"),
        "messages": extract_texts_messages_array,
        "sft": extract_texts_sft,
        "qa": extract_texts_qa,
        "combined": lambda r: (
            extract_texts_text_field(r, "output")
            or extract_texts_text_field(r, "outputs")
            or extract_texts_text_field(r, "generated")
            or extract_texts_text_field(r, "completion")
            or extract_texts_text_field(r, "combined")
            or extract_texts_text_field(r, "data")
            or extract_texts_text_field(r, "example")
        ),
        "completion": lambda r: extract_texts_text_field(r, "completion"),
        "generations": lambda r: (
            extract_texts_text_field(r, "generations")
            or extract_texts_text_field(r, "generation")
        ),
        "custom": lambda r: extract_texts_custom(r, custom_format_spec),
    }
    extractor = extractors.get(fmt_name)
    texts = extractor(rows) if extractor else []
    # Fixed seed keeps the subsample reproducible across runs.
    if max_samples and len(texts) > max_samples:
        random.seed(42)
        texts = random.sample(texts, max_samples)
    return {
        "texts": texts,
        "format": fmt_name,
        "format_info": fmt_info,
        "total_rows": total_rows,
        "supported": True,
        "error": None,
    }
def parse_custom_format_spec(spec):
    """
    Parse custom format specification.
    Supported formats:
    - "column: <field_name>" - extract single field as text
    - "column: <user_col>, column: <assistant_col>" - extract from two columns (user/assistant)
    - "pattern: <start_marker>user<end_marker>, pattern: <start_marker>assistant<end_marker>" - use regex patterns
    - "delimiter: <delim>" - use delimiter to split columns
    Examples:
    - "column: response"
    - "column: prompt, column: response"
    - "pattern: user:, pattern: assistant:"
    - "user:[startuser]assistant:[startassistant]"

    Returns a dict with keys type/user_field/assistant_field/user_pattern/
    assistant_pattern, or None for an empty spec.
    """
    if not spec:
        return None
    spec = spec.strip()
    result = {
        "type": None,
        "user_field": None,
        "assistant_field": None,
        "user_pattern": None,
        "assistant_pattern": None,
    }
    # Explicit column spec: "column: a" or "column: a, column: b".
    if spec.startswith("column:") or spec.startswith("col:"):
        cols_spec = spec.replace("column:", "").replace("col:", "").strip()
        if "," in cols_spec:
            parts = [p.strip() for p in cols_spec.split(",")]
            if len(parts) >= 2:
                result["type"] = "two_column"
                result["user_field"] = parts[0]
                result["assistant_field"] = parts[1]
        else:
            result["type"] = "single_column"
            result["assistant_field"] = cols_spec
        return result
    # Explicit regex spec: "pattern: p" or "pattern: pu, pattern: pa".
    if spec.startswith("pattern:") or spec.startswith("regex:"):
        patterns_spec = spec.replace("pattern:", "").replace("regex:", "").strip()
        if "," in patterns_spec:
            parts = [p.strip() for p in patterns_spec.split(",")]
            if len(parts) >= 2:
                result["type"] = "two_pattern"
                result["user_pattern"] = parts[0]
                result["assistant_pattern"] = parts[1]
        else:
            result["type"] = "single_pattern"
            result["assistant_pattern"] = patterns_spec
        return result
    # Inline "user: ... assistant: ..." spec.
    if "user:" in spec.lower() and "assistant:" in spec.lower():
        user_match = re.search(
            r"user:\s*(\[.*?\]|(?:(?!\s+assistant:).)+)",
            spec,
            re.IGNORECASE | re.DOTALL,
        )
        # BUG FIX: the assistant regex previously used
        # `(?:(?:\s+user:|$).)+`, which required " user:" or end-of-string
        # BEFORE each consumed character and therefore never matched, so
        # these specs silently degraded to single_column. Mirror the
        # negative lookahead used for the user pattern above.
        assistant_match = re.search(
            r"assistant:\s*(\[.*?\]|(?:(?!\s+user:).)+)",
            spec,
            re.IGNORECASE | re.DOTALL,
        )
        if user_match and assistant_match:
            result["type"] = "two_pattern"
            result["user_pattern"] = user_match.group(1).strip()
            result["assistant_pattern"] = assistant_match.group(1).strip()
            return result
    # Bracket-marker shorthand.
    if "[startuser]" in spec and "[startassistant]" in spec:
        result["type"] = "two_pattern"
        result["user_pattern"] = re.escape("[startuser]")
        result["assistant_pattern"] = re.escape("[startassistant]")
        return result
    # Bare "a, b" means two columns.
    if "," in spec:
        parts = [p.strip() for p in spec.split(",")]
        if len(parts) >= 2:
            result["type"] = "two_column"
            result["user_field"] = parts[0]
            result["assistant_field"] = parts[1]
            return result
    # Anything else is a single column name.
    result["type"] = "single_column"
    result["assistant_field"] = spec
    return result
def extract_texts_custom(rows, format_spec):
    """Extract texts using custom format specification.

    Column modes read the assistant column only (the user column of
    two_column mode is not needed for response-only evaluation — the
    previous version fetched it into an unused local). Pattern modes run
    the assistant regex over str(row); group(1) is used when the pattern
    has a capture group. Kept texts are CoT-stripped and > 50 characters.
    """
    parsed = parse_custom_format_spec(format_spec)
    if not parsed or not parsed.get("type"):
        return []
    fmt_type = parsed["type"]
    texts = []
    if fmt_type in ("single_column", "two_column"):
        field = parsed["assistant_field"]
        for row in rows:
            content = row.get(field, "")
            if content and len(str(content)) > 50:
                response_only = _extract_response_only(str(content))
                if response_only and len(response_only) > 50:
                    texts.append(response_only)
        return texts
    if fmt_type in ("single_pattern", "two_pattern"):
        pattern = parsed.get("assistant_pattern")
        if not pattern:
            return texts
        try:
            regex = re.compile(pattern, re.DOTALL | re.IGNORECASE)
            # Preserve original behavior: an invalid user_pattern in
            # two_pattern mode aborts extraction entirely.
            if fmt_type == "two_pattern" and parsed.get("user_pattern"):
                re.compile(parsed["user_pattern"], re.DOTALL | re.IGNORECASE)
        except re.error:
            return texts
        for row in rows:
            match = regex.search(str(row))
            if match:
                content = match.group(1) if match.groups() else match.group(0)
                if content and len(content) > 50:
                    response_only = _extract_response_only(content)
                    if response_only and len(response_only) > 50:
                        texts.append(response_only)
    return texts
def check_custom_format(rows, format_spec):
    """Return True when *format_spec* plausibly applies to these rows.

    Column specs are validated by key presence in the first row; pattern
    specs by a regex search over str(first row).
    """
    parsed = parse_custom_format_spec(format_spec)
    if not (parsed and parsed.get("type") and rows):
        return False
    sample = rows[0]
    fmt_type = parsed["type"]
    if fmt_type in ("single_column", "two_column"):
        # Only the assistant column is required for extraction.
        return parsed.get("assistant_field") in sample
    if fmt_type in ("single_pattern", "two_pattern"):
        pattern = parsed.get("assistant_pattern")
        if pattern:
            try:
                compiled = re.compile(pattern, re.DOTALL | re.IGNORECASE)
            except re.error:
                return False
            return compiled.search(str(sample)) is not None
    return False
def get_supported_formats():
    """Return the SUPPORTED_FORMATS registry (a dict keyed by format id).

    NOTE: this returns the live module-level dict, not a copy; mutating the
    result changes format detection globally.
    """
    return SUPPORTED_FORMATS