| """ |
| AIFinder Data Loader |
| Downloads and parses HuggingFace datasets, extracts assistant responses, |
| and labels them with is_ai, provider, and model. |
| """ |
|
|
| import os |
| import re |
| import time |
|
|
| from datasets import load_dataset |
| from tqdm import tqdm |
|
|
| from config import ( |
| DATASET_REGISTRY, |
| DEEPSEEK_AM_DATASETS, |
| MAX_SAMPLES_PER_PROVIDER, |
| ) |
|
|
# Optional HuggingFace access token for gated/private datasets; None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
|
|
|
|
| def _parse_msg(msg): |
| """Parse a message that may be a dict or a JSON string.""" |
| if isinstance(msg, dict): |
| return msg |
| if isinstance(msg, str): |
| try: |
| import json as _json |
|
|
| parsed = _json.loads(msg) |
| if isinstance(parsed, dict): |
| return parsed |
| except (ValueError, Exception): |
| pass |
| return {} |
|
|
|
|
| def _extract_response_only(content): |
| """Extract only the final response, stripping CoT blocks. |
| Returns only the text after </think> or </thinking> if present, |
| otherwise returns the full content. |
| """ |
| if not content: |
| return "" |
| think_match = re.search(r"</?think(?:ing)?>(.*)$", content, re.DOTALL) |
| if think_match: |
| response = think_match.group(1).strip() |
| if response: |
| return response |
| return content |
|
|
|
|
def _extract_assistant_texts_from_conversations(rows):
    """Extract individual assistant messages from conversation datasets.

    Returns one text per assistant turn (not concatenated) for cleaner
    samples. Only the response portion (after </think> if present) is kept.

    Args:
        rows: Iterable of dataset rows; each row may store its turns under
            "conversations" or "messages".

    Returns:
        list[str]: One entry per non-empty assistant turn.
    """
    texts = []
    for row in rows:
        convos = row.get("conversations")
        if convos is None or (hasattr(convos, "__len__") and len(convos) == 0):
            convos = row.get("messages")
        if convos is None or (hasattr(convos, "__len__") and len(convos) == 0):
            convos = []
        for msg in convos:
            msg = _parse_msg(msg)
            # OpenAI-style messages use "role"/"content"; ShareGPT-style use
            # "from"/"value". The role check already accepted the ShareGPT
            # "gpt" label, but those messages were dropped because their text
            # lives under "value" -- read both key pairs.
            role = msg.get("role") or msg.get("from") or ""
            content = msg.get("content") or msg.get("value") or ""
            if role in ("assistant", "gpt", "model") and content:
                response_only = _extract_response_only(content)
                if response_only:
                    texts.append(response_only)
    return texts
|
|
|
|
def _extract_from_am_dataset(row):
    """Collect assistant-turn texts from an a-m-team formatted row.

    Each kept entry is only the response portion of the message (text
    after a closing </think> tag, when one exists).

    Args:
        row: Dataset row with turns under "messages" or "conversations".

    Returns:
        list[str]: Non-empty assistant responses, in turn order.
    """
    collected = []
    for entry in row.get("messages") or row.get("conversations") or []:
        # Non-dict entries carry no role/content and are skipped.
        if not isinstance(entry, dict):
            continue
        if entry.get("role", "") != "assistant":
            continue
        body = entry.get("content", "")
        if not body:
            continue
        answer = _extract_response_only(body)
        if answer:
            collected.append(answer)
    return collected
|
|
|
|
def load_teichai_dataset(dataset_id, provider, model_name, kwargs):
    """Load a single conversation-format dataset.

    Args:
        dataset_id: HuggingFace dataset identifier.
        provider: Provider label applied to every extracted text.
        model_name: Model label applied to every extracted text.
        kwargs: Options dict; honors "max_samples" and config "name".

    Returns:
        Tuple of (texts, providers, models) lists; all empty on failure.
    """
    max_samples = kwargs.get("max_samples")
    load_kwargs = {"name": kwargs["name"]} if "name" in kwargs else {}

    try:
        rows = list(
            load_dataset(dataset_id, split="train", token=HF_TOKEN, **load_kwargs)
        )
    except Exception as primary_err:
        # The datasets library failed; fall back to fetching the first
        # parquet shard of the default config directly.
        try:
            import pandas as pd

            shard_url = (
                f"https://huggingface.co/api/datasets/{dataset_id}"
                "/parquet/default/train/0.parquet"
            )
            rows = pd.read_parquet(shard_url).to_dict(orient="records")
        except Exception as fallback_err:
            print(f" [SKIP] {dataset_id}: {primary_err} / parquet fallback: {fallback_err}")
            return [], [], []

    # Deterministic subsample so repeated runs see the same rows.
    if max_samples and len(rows) > max_samples:
        import random

        random.seed(42)
        rows = random.sample(rows, max_samples)

    extracted = _extract_assistant_texts_from_conversations(rows)

    # Drop very short snippets (<= 50 chars).
    kept = [t for t in extracted if len(t) > 50]
    if not kept:
        print(f" [SKIP] {dataset_id}: no valid texts extracted")
        return [], [], []

    return kept, [provider] * len(kept), [model_name] * len(kept)
|
|
|
|
def load_am_deepseek_dataset(dataset_id, provider, model_name, kwargs):
    """Load an a-m-team DeepSeek dataset and label its assistant texts.

    Args:
        dataset_id: HuggingFace dataset identifier.
        provider: Provider label applied to every extracted text.
        model_name: Model label applied to every extracted text.
        kwargs: Options dict; honors "max_samples" and config "name".

    Returns:
        Tuple of (texts, providers, models) lists; all empty on failure.
    """
    max_samples = kwargs.get("max_samples")
    load_kwargs = {"name": kwargs["name"]} if "name" in kwargs else {}

    try:
        ds = load_dataset(dataset_id, split="train", token=HF_TOKEN, **load_kwargs)
    except Exception:
        # Non-streaming load failed; retry in streaming mode and stop
        # early once enough rows have been collected.
        try:
            stream = load_dataset(
                dataset_id, split="train", streaming=True, token=HF_TOKEN, **load_kwargs
            )
            rows = []
            for item in stream:
                rows.append(item)
                if max_samples and len(rows) >= max_samples:
                    break
        except Exception as e2:
            print(f" [SKIP] {dataset_id}: {e2}")
            return [], [], []
    else:
        rows = list(ds)[:max_samples] if max_samples else list(ds)

    # Keep only assistant responses longer than 50 characters.
    texts = [
        snippet
        for row in rows
        for snippet in _extract_from_am_dataset(row)
        if len(snippet) > 50
    ]
    return texts, [provider] * len(texts), [model_name] * len(texts)
|
|
|
|
def load_all_data():
    """Load all datasets and return combined lists.

    Pipeline: load every registered dataset, remove exact duplicates,
    cap each provider at MAX_SAMPLES_PER_PROVIDER, label all samples AI.

    Returns:
        texts: list of str
        providers: list of str
        models: list of str
        is_ai: list of int (1=AI, 0=Human)
    """
    all_texts = []
    all_providers = []
    all_models = []

    # --- Conversation-format (TeichAI) datasets ---
    print("Loading TeichAI datasets...")
    for dataset_id, provider, model_name, kwargs in tqdm(
        DATASET_REGISTRY, desc="TeichAI"
    ):
        t0 = time.time()
        texts, providers, models = load_teichai_dataset(
            dataset_id, provider, model_name, kwargs
        )
        elapsed = time.time() - t0
        all_texts.extend(texts)
        all_providers.extend(providers)
        all_models.extend(models)
        print(f" {dataset_id}: {len(texts)} samples ({elapsed:.1f}s)")

    # --- a-m-team DeepSeek datasets ---
    print("\nLoading DeepSeek (a-m-team) datasets...")
    for dataset_id, provider, model_name, kwargs in tqdm(
        DEEPSEEK_AM_DATASETS, desc="DeepSeek-AM"
    ):
        t0 = time.time()
        texts, providers, models = load_am_deepseek_dataset(
            dataset_id, provider, model_name, kwargs
        )
        elapsed = time.time() - t0
        all_texts.extend(texts)
        all_providers.extend(providers)
        all_models.extend(models)
        print(f" {dataset_id}: {len(texts)} samples ({elapsed:.1f}s)")

    # --- Exact-duplicate removal ---
    import hashlib
    import random as _rng

    # Fixed seed so the per-provider subsampling below is reproducible.
    _rng.seed(42)

    # Fingerprint a case/whitespace-normalized copy of each text; first
    # occurrence wins (md5 is used as a fingerprint, not for security).
    seen = set()
    dedup_texts, dedup_providers, dedup_models = [], [], []
    for t, p, m in zip(all_texts, all_providers, all_models):
        h = hashlib.md5(t.strip().lower().encode()).hexdigest()
        if h not in seen:
            seen.add(h)
            dedup_texts.append(t)
            dedup_providers.append(p)
            dedup_models.append(m)

    n_dupes = len(all_texts) - len(dedup_texts)
    if n_dupes > 0:
        print(f"\n Removed {n_dupes} duplicate samples")

    # --- Per-provider balancing ---
    from collections import defaultdict

    # Group sample indices by provider so each can be capped independently.
    provider_indices = defaultdict(list)
    for i, p in enumerate(dedup_providers):
        provider_indices[p].append(i)

    # Shuffle within each provider, keep at most MAX_SAMPLES_PER_PROVIDER,
    # then sort the kept indices to restore the original corpus order.
    keep_indices = []
    for p, idxs in provider_indices.items():
        _rng.shuffle(idxs)
        n_sample = min(len(idxs), MAX_SAMPLES_PER_PROVIDER)
        idxs = idxs[:n_sample]
        print(f" Sampled {p}: {len(idxs)} samples")
        keep_indices.extend(idxs)
    keep_indices.sort()

    all_texts = [dedup_texts[i] for i in keep_indices]
    all_providers = [dedup_providers[i] for i in keep_indices]
    all_models = [dedup_models[i] for i in keep_indices]

    # Every sample loaded here is model-generated; human (is_ai=0) samples
    # are presumably merged in elsewhere -- TODO confirm against callers.
    is_ai = [1] * len(all_texts)

    print(f"\n=== Total: {len(all_texts)} samples ===")
    # Per-provider summary, most common first.
    from collections import Counter

    prov_counts = Counter(all_providers)
    for p, c in sorted(prov_counts.items(), key=lambda x: -x[1]):
        print(f" {p}: {c}")

    return all_texts, all_providers, all_models, is_ai
|
|
|
|
if __name__ == "__main__":
    # Smoke-test entry point: run the full pipeline; per-dataset and
    # per-provider counts are printed by load_all_data itself.
    texts, providers, models, is_ai = load_all_data()
|
|