Spaces:
Running
Running
"""
Dataset Preparation Script.

Downloads HH-RLHF, UltraFeedback, and Stanford SHP from Hugging Face
and converts them into the format expected by PreferenceLab.

Usage:
    python scripts/prepare_datasets.py
    python scripts/prepare_datasets.py --samples 200
"""
import argparse
import json
import random
from pathlib import Path

# Output directory: <repo_root>/data (this file lives in scripts/).
DATA_DIR = Path(__file__).parent.parent / "data"
DATA_DIR.mkdir(exist_ok=True)
def prepare_pairwise(n_samples: int = 100):
    """Download Anthropic HH-RLHF and convert to pairwise format.

    Writes data/pairwise_data.json: records with a prompt, two responses
    (sides randomly swapped) and a gold_label naming the better side.
    On any failure (e.g. `datasets` missing) prints a warning and returns,
    leaving the caller to rely on its synthetic fallback.
    """
    print(f"[1/3] Preparing pairwise data (HH-RLHF, {n_samples} samples)...")
    try:
        from datasets import load_dataset

        stream = load_dataset("Anthropic/hh-rlhf", split="train", streaming=True)
        records = []
        for idx, example in enumerate(stream):
            if idx >= n_samples:
                break
            chosen_text = example.get("chosen", "")
            rejected_text = example.get("rejected", "")
            # Split the transcript on assistant-turn markers: the chunk before
            # the first marker is the opening human turn (used as the prompt
            # — NOTE(review): for multi-turn transcripts the intermediate
            # turns are dropped), the last chunk is the final assistant reply.
            chosen_parts = chosen_text.split("\n\nAssistant:")
            if len(chosen_parts) >= 2:
                prompt_block = chosen_parts[0].replace("Human:", "").strip()
                better = chosen_parts[-1].strip()
            else:
                prompt_block = chosen_text[:100]
                better = chosen_text
            rejected_parts = rejected_text.split("\n\nAssistant:")
            worse = rejected_parts[-1].strip() if len(rejected_parts) >= 2 else rejected_text
            # Coin flip decides which side the better answer lands on so a
            # judge cannot exploit position bias; gold_label records the truth.
            better_first = random.random() < 0.5
            records.append({
                "prompt": prompt_block,
                "response_a": better if better_first else worse,
                "response_b": worse if better_first else better,
                "gold_label": "A" if better_first else "B",
                "source": "hh-rlhf",
            })
        out = DATA_DIR / "pairwise_data.json"
        with open(out, "w") as f:
            json.dump(records, f, indent=2)
        print(f" β Saved {len(records)} pairwise examples β {out}")
    except Exception as e:
        print(f" β Failed: {e} β synthetic fallback will be used")
def prepare_likert(n_samples: int = 100):
    """Download UltraFeedback and convert to likert (rubric-scored) format.

    Writes data/likert_data.json: one record per instruction, scoring the
    first completion on four axes using UltraFeedback's per-aspect
    annotations. On any failure prints a warning and returns, leaving the
    caller to rely on its synthetic fallback.

    Args:
        n_samples: maximum number of examples to emit.
    """
    print(f"[2/3] Preparing likert data (UltraFeedback, {n_samples} samples)...")

    # Hoisted out of the loop: the original re-defined this closure on every
    # iteration for no benefit.
    def _extract_score(annots: dict, key: str, default: int = 3) -> int:
        """Pull a 1-5 integer rating for *key* from an annotations mapping."""
        val = annots.get(key, {})
        if isinstance(val, dict):
            raw = val.get("Rating", default)
        elif isinstance(val, (int, float)):
            raw = val
        else:
            raw = default
        # UltraFeedback uses a 1-5 scale; clamp defensively.
        try:
            return max(1, min(5, int(raw)))
        except Exception:
            return default

    try:
        from datasets import load_dataset

        ds = load_dataset("openbmb/UltraFeedback", split="train", streaming=True)
        records = []
        for i, ex in enumerate(ds):
            if i >= n_samples:
                break
            instr = ex.get("instruction", "")
            completions = ex.get("completions", [])
            if not completions:
                continue
            comp = completions[0]
            response = comp.get("response", "")
            annots = comp.get("annotations", {})
            records.append({
                "prompt": instr,
                "response": response,
                "rubric": (
                    "Score on 4 axes (1=worst, 5=best): helpfulness, honesty, "
                    "harmlessness, instruction_following."
                ),
                "gold_scores": {
                    # BUG FIX: previously read the "instruction_following"
                    # annotation here too, duplicating that axis and ignoring
                    # the actual "helpfulness" rating.
                    "helpfulness": _extract_score(annots, "helpfulness"),
                    "honesty": _extract_score(annots, "honesty"),
                    # UltraFeedback has no harmlessness aspect; truthfulness
                    # is used as a proxy with a mildly optimistic default.
                    "harmlessness": _extract_score(annots, "truthfulness", 4),
                    "instruction_following": _extract_score(annots, "instruction_following"),
                },
                "source": "ultrafeedback",
            })
        out = DATA_DIR / "likert_data.json"
        with open(out, "w") as f:
            json.dump(records, f, indent=2)
        print(f" β Saved {len(records)} likert examples β {out}")
    except Exception as e:
        print(f" β Failed: {e} β synthetic fallback will be used")
def prepare_consistency(n_samples: int = 60):
    """Build 4-way ranking examples from Stanford SHP.

    Writes data/consistency_data.json: each record shows four responses
    (A-D, display order shuffled) plus the gold best-to-worst ranking.
    On any failure prints a warning and returns, leaving the caller to
    rely on its synthetic fallback.

    Args:
        n_samples: maximum number of ranking examples to emit.
    """
    print(f"[3/3] Preparing consistency data (Stanford SHP, {n_samples} samples)...")
    try:
        from datasets import load_dataset

        ds = load_dataset("stanfordnlp/SHP", split="train", streaming=True)
        # Group by post_id to collect multiple responses per prompt.
        # BUG FIX: each SHP row carries TWO responses (human_ref_A/B) with
        # absolute vote counts score_A/score_B. The old code kept only one
        # ref per row and scored it with score_ratio — a relative A-vs-B
        # quantity — so the resulting "gold" ranking was meaningless.
        grouped: dict[str, dict] = {}
        for ex in ds:
            pid = ex.get("post_id", "")
            entry = grouped.setdefault(
                pid, {"prompt": ex.get("history", ""), "responses": [], "_seen": set()}
            )
            for text_key, score_key in (("human_ref_A", "score_A"),
                                        ("human_ref_B", "score_B")):
                text = ex.get(text_key, "")
                # Skip empties and duplicates: the same comment can appear in
                # several pairings of one post.
                if text and text not in entry["_seen"]:
                    entry["_seen"].add(text)
                    entry["responses"].append({
                        "text": text,
                        "score": ex.get(score_key, 0),
                    })
            if len(grouped) >= n_samples * 3:
                break
        records = []
        labels = ["A", "B", "C", "D"]
        for data in grouped.values():
            resps = data["responses"]
            if len(resps) < 4:
                continue
            # Sort by score descending = gold order (best first).
            resps_sorted = sorted(resps[:4], key=lambda r: r["score"], reverse=True)
            # Shuffle the display order (not the gold order). Gold rank is
            # tracked by object identity: list.index compares dicts by value
            # and would collapse equal-looking entries.
            shuffled = resps_sorted[:]
            random.shuffle(shuffled)
            rank_of = {id(r): rank for rank, r in enumerate(resps_sorted)}
            id_map = dict(zip(labels, shuffled))
            gold_ranking = sorted(labels, key=lambda lab: rank_of[id(id_map[lab])])
            records.append({
                "prompt": data["prompt"][:500],
                "response_a": id_map["A"]["text"][:400],
                "response_b": id_map["B"]["text"][:400],
                "response_c": id_map["C"]["text"][:400],
                "response_d": id_map["D"]["text"][:400],
                "gold_ranking": gold_ranking,
                "source": "stanford-shp",
            })
            if len(records) >= n_samples:
                break
        out = DATA_DIR / "consistency_data.json"
        with open(out, "w") as f:
            json.dump(records, f, indent=2)
        print(f" β Saved {len(records)} consistency examples β {out}")
    except Exception as e:
        print(f" β Failed: {e} β synthetic fallback will be used")
def main():
    """CLI entry point: parse --samples and run all three preparation steps."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--samples", type=int, default=100,
                        help="Number of samples per task (default: 100)")
    args = parser.parse_args()

    banner = "=" * 50
    print(banner)
    print("PreferenceLab Dataset Preparation")
    print(banner)

    prepare_pairwise(args.samples)
    prepare_likert(args.samples)
    # Ranking examples are 4-way, so half the sample budget suffices.
    prepare_consistency(args.samples // 2)
    print("\nβ Done. Run inference.py to test.")


if __name__ == "__main__":
    main()