"""
AIFinder Data Loader

Downloads and parses HuggingFace datasets, extracts assistant responses,
and labels them with is_ai, provider, and model.
"""

import hashlib
import json
import os
import random
import re
import time
from collections import Counter, defaultdict
from itertools import islice

from datasets import load_dataset
from tqdm import tqdm

from config import (
    DATASET_REGISTRY,
    DEEPSEEK_AM_DATASETS,
    MAX_SAMPLES_PER_PROVIDER,
)

HF_TOKEN = os.environ.get("HF_TOKEN")

# Extracted texts with this many characters or fewer are discarded.
MIN_TEXT_CHARS = 50

# Seed used by every sampling/shuffling step so the dataset is reproducible.
RANDOM_SEED = 42

# Role names accepted as assistant turns in conversation-format datasets
# ("gpt" is the ShareGPT convention, which stores turns as from/value).
_ASSISTANT_ROLES = ("assistant", "gpt", "model")

# Closing tag of a chain-of-thought block; the final answer follows it.
_THINK_CLOSE_RE = re.compile(r"</think(?:ing)?>", re.IGNORECASE)


def _parse_msg(msg):
    """Normalize a chat message to a dict.

    Messages may arrive as dicts or as JSON-encoded strings. Anything that
    cannot be interpreted as a JSON object yields an empty dict so callers
    can use ``.get`` unconditionally.
    """
    if isinstance(msg, dict):
        return msg
    if isinstance(msg, str):
        try:
            parsed = json.loads(msg)
        except (ValueError, RecursionError):
            # Not JSON (or pathologically nested): treat as an empty message.
            return {}
        if isinstance(parsed, dict):
            return parsed
    return {}


def _extract_response_only(content):
    """Extract only the final response, stripping CoT blocks.

    Returns the (stripped) text after the last </think> or </thinking> tag
    if one is present. If there is no such tag, or nothing follows it, the
    stripped full content is returned instead.
    """
    if not content:
        return ""
    *reasoning, answer = _THINK_CLOSE_RE.split(content)
    answer = answer.strip()
    if reasoning and answer:
        return answer
    return content.strip()


def _first_nonempty(row, *keys):
    """Return the first value of ``row`` under ``keys`` that is non-empty.

    Emptiness is tested with ``len()`` rather than truthiness because rows
    from the pandas parquet fallback may hold numpy arrays, whose truth
    value is ambiguous. Returns ``[]`` if every key is missing or empty.
    """
    for key in keys:
        value = row.get(key)
        if value is None:
            continue
        if hasattr(value, "__len__") and len(value) == 0:
            continue
        return value
    return []


def _message_role_and_content(msg):
    """Return ``(role, content)`` for a message in role/content or from/value form."""
    msg = _parse_msg(msg)
    role = msg.get("role") or msg.get("from") or ""
    content = msg.get("content") or msg.get("value") or ""
    return role, content


def _assistant_responses(messages, roles=_ASSISTANT_ROLES):
    """Yield the CoT-stripped text of each assistant turn in ``messages``.

    Args:
        messages: Iterable of messages (dicts or JSON strings).
        roles: Role names treated as assistant turns.
    """
    for msg in messages:
        role, content = _message_role_and_content(msg)
        # Non-string content (e.g. multi-part payloads) cannot be regex-parsed.
        if role not in roles or not isinstance(content, str) or not content:
            continue
        response = _extract_response_only(content)
        if response:
            yield response


def _extract_assistant_texts_from_conversations(rows):
    """Extract individual assistant messages from conversation datasets.

    Returns one text per assistant turn (not concatenated) for cleaner
    samples. Only the response portion (after </think> / </thinking> if
    present) is kept. Each row's turns are read from its "conversations"
    column, falling back to "messages".
    """
    texts = []
    for row in rows:
        convos = _first_nonempty(row, "conversations", "messages")
        texts.extend(_assistant_responses(convos))
    return texts


def _extract_from_am_dataset(row):
    """Extract individual assistant texts from one a-m-team format row.

    Only turns with role "assistant" are used, and only their response
    portion (after </think> / </thinking> if present).
    """
    messages = _first_nonempty(row, "messages", "conversations")
    return list(_assistant_responses(messages, roles=("assistant",)))


def load_teichai_dataset(dataset_id, provider, model_name, kwargs):
    """Load a single conversation-format dataset.

    Args:
        dataset_id: HuggingFace dataset repo id.
        provider: Provider label attached to every sample.
        model_name: Model label attached to every sample.
        kwargs: Per-dataset options. ``max_samples`` caps the number of rows
            (random sample with a fixed seed); ``name`` selects the config.

    Returns:
        ``(texts, providers, models)`` lists of equal length, or three empty
        lists if the dataset cannot be loaded or yields no usable text.
    """
    max_samples = kwargs.get("max_samples")
    load_kwargs = {}
    if "name" in kwargs:
        load_kwargs["name"] = kwargs["name"]

    try:
        ds = load_dataset(dataset_id, split="train", token=HF_TOKEN, **load_kwargs)
        rows = list(ds)
    except Exception as e:
        # Fallback: load the Hub's auto-converted parquet export for the same
        # config. Only shard 0 is read; any further shards are ignored.
        parquet_config = kwargs.get("name") or "default"
        url = (
            f"https://huggingface.co/api/datasets/{dataset_id}"
            f"/parquet/{parquet_config}/train/0.parquet"
        )
        try:
            import pandas as pd  # optional dependency, only needed here

            rows = pd.read_parquet(url).to_dict(orient="records")
        except Exception as e2:
            print(f" [SKIP] {dataset_id}: {e} / parquet fallback: {e2}")
            return [], [], []

    if max_samples and len(rows) > max_samples:
        # A private RNG draws exactly what seeding the global ``random``
        # module would, without disturbing other users of that state.
        rows = random.Random(RANDOM_SEED).sample(rows, max_samples)

    # Filter out empty/too-short texts
    texts = [
        text
        for text in _extract_assistant_texts_from_conversations(rows)
        if len(text) > MIN_TEXT_CHARS
    ]
    if not texts:
        print(f" [SKIP] {dataset_id}: no valid texts extracted")
        return [], [], []
    return texts, [provider] * len(texts), [model_name] * len(texts)


def _take_rows(iterable, limit):
    """Return the first ``limit`` items of ``iterable`` (all of them if falsy)."""
    if limit:
        return list(islice(iterable, limit))
    return list(iterable)


def load_am_deepseek_dataset(dataset_id, provider, model_name, kwargs):
    """Load an a-m-team DeepSeek dataset.

    Tries a regular download first and falls back to streaming. In both
    cases the first ``kwargs["max_samples"]`` rows are taken (all rows if
    unset), without materializing the rest of the dataset.

    Returns:
        ``(texts, providers, models)`` lists of equal length, or three empty
        lists if the dataset cannot be loaded.
    """
    max_samples = kwargs.get("max_samples")
    load_kwargs = {}
    if "name" in kwargs:
        load_kwargs["name"] = kwargs["name"]

    try:
        ds = load_dataset(dataset_id, split="train", token=HF_TOKEN, **load_kwargs)
    except Exception:
        try:
            stream = load_dataset(
                dataset_id, split="train", streaming=True, token=HF_TOKEN, **load_kwargs
            )
            # Iterating a stream performs network I/O, so it stays inside the
            # try to report failures as a skip rather than crash.
            rows = _take_rows(stream, max_samples)
        except Exception as e2:
            print(f" [SKIP] {dataset_id}: {e2}")
            return [], [], []
    else:
        rows = _take_rows(ds, max_samples)

    texts = [
        text
        for row in rows
        for text in _extract_from_am_dataset(row)
        if len(text) > MIN_TEXT_CHARS
    ]
    return texts, [provider] * len(texts), [model_name] * len(texts)


def _load_group(registry, loader, desc):
    """Run ``loader`` on every ``(dataset_id, provider, model, kwargs)`` entry.

    Returns the concatenated ``(texts, providers, models)`` lists and prints
    a per-dataset sample count and load time.
    """
    texts, providers, models = [], [], []
    for dataset_id, provider, model_name, kwargs in tqdm(registry, desc=desc):
        start = time.perf_counter()
        t, p, m = loader(dataset_id, provider, model_name, kwargs)
        elapsed = time.perf_counter() - start
        texts.extend(t)
        providers.extend(p)
        models.extend(m)
        print(f" {dataset_id}: {len(t)} samples ({elapsed:.1f}s)")
    return texts, providers, models


def _deduplicate(texts, providers, models):
    """Drop samples whose whitespace-stripped, lowercased text was seen before.

    The first occurrence wins and input order is preserved.
    """
    seen = set()
    out_texts, out_providers, out_models = [], [], []
    for t, p, m in zip(texts, providers, models):
        # surrogatepass: lone surrogates (possible in JSON-decoded text) would
        # otherwise raise UnicodeEncodeError; valid text hashes identically.
        key = hashlib.md5(
            t.strip().lower().encode("utf-8", "surrogatepass")
        ).hexdigest()
        if key in seen:
            continue
        seen.add(key)
        out_texts.append(t)
        out_providers.append(p)
        out_models.append(m)
    return out_texts, out_providers, out_models


def _cap_per_provider(texts, providers, models, rng):
    """Keep a random subset of at most MAX_SAMPLES_PER_PROVIDER per provider.

    Providers are shuffled with ``rng`` in order of first appearance; the
    surviving samples are returned in their original relative order.
    """
    provider_indices = defaultdict(list)
    for i, p in enumerate(providers):
        provider_indices[p].append(i)

    keep_indices = []
    for p, idxs in provider_indices.items():
        rng.shuffle(idxs)
        idxs = idxs[: min(len(idxs), MAX_SAMPLES_PER_PROVIDER)]
        print(f" Sampled {p}: {len(idxs)} samples")
        keep_indices.extend(idxs)
    keep_indices.sort()

    return (
        [texts[i] for i in keep_indices],
        [providers[i] for i in keep_indices],
        [models[i] for i in keep_indices],
    )


def load_all_data():
    """Load all datasets and return combined lists.

    Samples are deduplicated by normalized text, then capped per provider.

    Returns:
        texts: list of str
        providers: list of str
        models: list of str
        is_ai: list of int (1=AI, 0=Human); every sample loaded here is AI,
            so this is all 1s.
    """
    print("Loading TeichAI datasets...")
    texts, providers, models = _load_group(
        DATASET_REGISTRY, load_teichai_dataset, "TeichAI"
    )

    print("\nLoading DeepSeek (a-m-team) datasets...")
    am_texts, am_providers, am_models = _load_group(
        DEEPSEEK_AM_DATASETS, load_am_deepseek_dataset, "DeepSeek-AM"
    )
    texts += am_texts
    providers += am_providers
    models += am_models

    # Private RNG seeded like the global one would be: same shuffles,
    # no hidden global side effect.
    rng = random.Random(RANDOM_SEED)

    n_before = len(texts)
    texts, providers, models = _deduplicate(texts, providers, models)
    n_dupes = n_before - len(texts)
    if n_dupes > 0:
        print(f"\n Removed {n_dupes} duplicate samples")

    texts, providers, models = _cap_per_provider(texts, providers, models, rng)

    is_ai = [1] * len(texts)

    print(f"\n=== Total: {len(texts)} samples ===")
    # most_common() orders by count descending, ties in first-seen order.
    for p, c in Counter(providers).most_common():
        print(f" {p}: {c}")

    return texts, providers, models, is_ai


if __name__ == "__main__":
    texts, providers, models, is_ai = load_all_data()