# preference-lab / scripts / prepare_datasets.py
# PreferenceLab OpenEnv environment for RLHF preference simulation
"""
Dataset Preparation Script.
Downloads HH-RLHF, UltraFeedback, and Stanford SHP from Hugging Face
and converts them into the format expected by PreferenceLab.
Usage:
python scripts/prepare_datasets.py
python scripts/prepare_datasets.py --samples 200
"""
import argparse
import json
import random
from pathlib import Path
# All prepared datasets land in <repo-root>/data (sibling of scripts/).
# Created eagerly at import time so the prepare_* functions can write to it.
DATA_DIR = Path(__file__).parent.parent / "data"
DATA_DIR.mkdir(exist_ok=True)
def prepare_pairwise(n_samples: int = 100):
    """Download Anthropic HH-RLHF and convert to pairwise format.

    Writes ``data/pairwise_data.json``: a list of records with keys
    ``prompt``, ``response_a``, ``response_b``, ``gold_label`` ("A" or "B")
    and ``source``. The better (chosen) response is randomly assigned to
    slot A or B to avoid position bias; ``gold_label`` records where it went.

    Args:
        n_samples: Maximum number of examples to emit.

    Failures (missing ``datasets`` package, network errors) are caught and
    reported; the caller's synthetic fallback is used instead.
    """
    print(f"[1/3] Preparing pairwise data (HH-RLHF, {n_samples} samples)...")
    try:
        from datasets import load_dataset
        ds = load_dataset("Anthropic/hh-rlhf", split="train", streaming=True)
        records = []
        for i, ex in enumerate(ds):
            if i >= n_samples:
                break
            # chosen = better response, rejected = worse
            chosen = ex.get("chosen", "")
            rejected = ex.get("rejected", "")
            # Transcripts look like "\n\nHuman: ...\n\nAssistant: ...".
            # Splitting on assistant turns leaves the final assistant reply as
            # the last segment; the segment before it ends with the last
            # human turn.
            lines = chosen.split("\n\nAssistant:")
            if len(lines) >= 2:
                # BUG FIX: previously used lines[0] (the FIRST human turn),
                # which in multi-turn transcripts does not match resp_a
                # (the LAST assistant reply). Extract the last human turn,
                # as the original comment intended.
                prompt_block = lines[-2].split("\n\nHuman:")[-1].strip()
                resp_a = lines[-1].strip()
            else:
                prompt_block = chosen[:100]
                resp_a = chosen
            rej_lines = rejected.split("\n\nAssistant:")
            resp_b = rej_lines[-1].strip() if len(rej_lines) >= 2 else rejected
            # Randomly swap A/B to avoid position bias, track gold
            if random.random() < 0.5:
                slot_a, slot_b, gold = resp_a, resp_b, "A"
            else:
                slot_a, slot_b, gold = resp_b, resp_a, "B"
            records.append({
                "prompt": prompt_block,
                "response_a": slot_a,
                "response_b": slot_b,
                "gold_label": gold,
                "source": "hh-rlhf",
            })
        out = DATA_DIR / "pairwise_data.json"
        # utf-8 + ensure_ascii=False keeps non-ASCII text readable on disk.
        with open(out, "w", encoding="utf-8") as f:
            json.dump(records, f, indent=2, ensure_ascii=False)
        print(f" βœ“ Saved {len(records)} pairwise examples β†’ {out}")
    except Exception as e:
        # Deliberate best-effort: report and let the synthetic fallback run.
        print(f" βœ— Failed: {e} β€” synthetic fallback will be used")
def prepare_likert(n_samples: int = 100):
    """Download UltraFeedback and convert to likert format.

    Writes ``data/likert_data.json``: a list of records with keys ``prompt``,
    ``response``, ``rubric``, ``gold_scores`` (four 1-5 integer axes) and
    ``source``. Only the first completion of each example is used.

    Args:
        n_samples: Maximum number of examples to emit.

    Failures (missing ``datasets`` package, network errors) are caught and
    reported; the caller's synthetic fallback is used instead.
    """
    print(f"[2/3] Preparing likert data (UltraFeedback, {n_samples} samples)...")

    # Hoisted out of the loop (was re-defined on every iteration).
    def _extract_score(annots: dict, key: str, default: int = 3) -> int:
        """Pull a clamped 1-5 rating for *key* from an annotation dict."""
        val = annots.get(key, {})
        if isinstance(val, dict):
            raw = val.get("Rating", default)
        elif isinstance(val, (int, float)):
            raw = val
        else:
            raw = default
        # UltraFeedback uses 1-5 scale; ratings can be "N/A" strings, in
        # which case int() raises and we fall back to the default.
        try:
            return max(1, min(5, int(raw)))
        except Exception:
            return default

    try:
        from datasets import load_dataset
        ds = load_dataset("openbmb/UltraFeedback", split="train", streaming=True)
        records = []
        for i, ex in enumerate(ds):
            if i >= n_samples:
                break
            instr = ex.get("instruction", "")
            completions = ex.get("completions", [])
            if not completions:
                continue
            comp = completions[0]
            response = comp.get("response", "")
            annots = comp.get("annotations", {})
            records.append({
                "prompt": instr,
                "response": response,
                "rubric": (
                    "Score on 4 axes (1=worst, 5=best): helpfulness, honesty, "
                    "harmlessness, instruction_following."
                ),
                "gold_scores": {
                    # BUG FIX: "helpfulness" previously read the
                    # instruction_following annotation; use UltraFeedback's
                    # actual helpfulness rating.
                    "helpfulness": _extract_score(annots, "helpfulness"),
                    "honesty": _extract_score(annots, "honesty"),
                    # UltraFeedback has no harmlessness axis; truthfulness is
                    # the closest proxy (optimistic default of 4).
                    "harmlessness": _extract_score(annots, "truthfulness", 4),
                    "instruction_following": _extract_score(annots, "instruction_following"),
                },
                "source": "ultrafeedback",
            })
        out = DATA_DIR / "likert_data.json"
        # utf-8 + ensure_ascii=False keeps non-ASCII text readable on disk.
        with open(out, "w", encoding="utf-8") as f:
            json.dump(records, f, indent=2, ensure_ascii=False)
        print(f" βœ“ Saved {len(records)} likert examples β†’ {out}")
    except Exception as e:
        # Deliberate best-effort: report and let the synthetic fallback run.
        print(f" βœ— Failed: {e} β€” synthetic fallback will be used")
def prepare_consistency(n_samples: int = 60):
    """Build 4-way ranking examples from Stanford SHP.

    Streams SHP rows, groups them by ``post_id``, keeps the first four
    distinct responses per post, and writes ``data/consistency_data.json``
    with records holding four shuffled responses (A-D) plus a
    ``gold_ranking`` ordering the labels from best to worst by SHP
    ``score_ratio``.

    Args:
        n_samples: Maximum number of ranking examples to emit.

    Failures (missing ``datasets`` package, network errors) are caught and
    reported; the caller's synthetic fallback is used instead.
    """
    print(f"[3/3] Preparing consistency data (Stanford SHP, {n_samples} samples)...")
    try:
        from datasets import load_dataset
        ds = load_dataset("stanfordnlp/SHP", split="train", streaming=True)
        # Group by post_id to collect multiple responses per prompt
        grouped: dict[str, dict] = {}
        for ex in ds:
            pid = ex.get("post_id", "")
            if pid not in grouped:
                grouped[pid] = {
                    "prompt": ex.get("history", ""),
                    "responses": [],
                }
            grouped[pid]["responses"].append({
                "text": ex.get("human_ref_A", "") or ex.get("human_ref_B", ""),
                "score": ex.get("score_ratio", 1.0),
            })
            # Collect ~3x the target so enough groups reach 4 responses.
            if len(grouped) >= n_samples * 3:
                break
        records = []
        for pid, data in grouped.items():
            # BUG FIX: the same reply text can appear in several SHP
            # comparison rows for one post; drop empty and duplicate texts so
            # the 4-way ranking is over genuinely distinct responses.
            seen_texts = set()
            resps = []
            for r in data["responses"]:
                if r["text"] and r["text"] not in seen_texts:
                    seen_texts.add(r["text"])
                    resps.append(r)
            if len(resps) < 4:
                continue
            # Sort by score descending = gold ranking
            resps_sorted = sorted(resps[:4], key=lambda x: x["score"], reverse=True)
            labels = ["A", "B", "C", "D"]
            # Shuffle display order (not gold order)
            shuffled = resps_sorted[:]
            random.shuffle(shuffled)
            id_map = {labels[i]: shuffled[i] for i in range(4)}
            # BUG FIX: rank by object identity rather than list.index() —
            # two value-equal dicts would otherwise collapse onto one rank.
            rank_of = {id(r): pos for pos, r in enumerate(resps_sorted)}
            gold_ranking = sorted(labels, key=lambda lab: rank_of[id(id_map[lab])])
            records.append({
                "prompt": data["prompt"][:500],
                "response_a": id_map["A"]["text"][:400],
                "response_b": id_map["B"]["text"][:400],
                "response_c": id_map["C"]["text"][:400],
                "response_d": id_map["D"]["text"][:400],
                "gold_ranking": gold_ranking,
                "source": "stanford-shp",
            })
            if len(records) >= n_samples:
                break
        out = DATA_DIR / "consistency_data.json"
        # utf-8 + ensure_ascii=False keeps non-ASCII text readable on disk.
        with open(out, "w", encoding="utf-8") as f:
            json.dump(records, f, indent=2, ensure_ascii=False)
        print(f" βœ“ Saved {len(records)} consistency examples β†’ {out}")
    except Exception as e:
        # Deliberate best-effort: report and let the synthetic fallback run.
        print(f" βœ— Failed: {e} β€” synthetic fallback will be used")
def main():
    """CLI entry point: prepare all three PreferenceLab datasets."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--samples",
        type=int,
        default=100,
        help="Number of samples per task (default: 100)",
    )
    args = parser.parse_args()

    banner = "=" * 50
    print(banner)
    print("PreferenceLab Dataset Preparation")
    print(banner)

    sample_count = args.samples
    prepare_pairwise(sample_count)
    prepare_likert(sample_count)
    # Ranking examples need 4 distinct responses each, so build fewer.
    prepare_consistency(sample_count // 2)
    print("\nβœ“ Done. Run inference.py to test.")
# Standard script-entry guard: run only when executed directly, not on import.
if __name__ == "__main__":
    main()