File size: 7,450 Bytes
cdf485e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
"""
Dataset Preparation Script.

Downloads HH-RLHF, UltraFeedback, and Stanford SHP from Hugging Face
and converts them into the format expected by PreferenceLab.

Usage:
    python scripts/prepare_datasets.py
    python scripts/prepare_datasets.py --samples 200
"""

import argparse
import json
import random
from pathlib import Path

# Output directory: <repo root>/data (this script lives in <repo root>/scripts/).
DATA_DIR = Path(__file__).parent.parent / "data"
DATA_DIR.mkdir(exist_ok=True)  # repo root must already exist; only "data" is created


def _split_transcript(transcript: str):
    """Split an HH-RLHF transcript into (prompt, final assistant reply).

    Transcripts look like ``Human: ...\\n\\nAssistant: ...`` (possibly
    multi-turn).  The prompt is the text *before the first* assistant turn
    with any ``Human:`` tags stripped; the reply is the text after the
    *last* assistant turn.  A transcript with no assistant turn falls back
    to a truncated prompt and the full text as the reply.
    """
    turns = transcript.split("\n\nAssistant:")
    if len(turns) >= 2:
        return turns[0].replace("Human:", "").strip(), turns[-1].strip()
    # Malformed example: best-effort fallback (matches historical behavior).
    return transcript[:100], transcript


def prepare_pairwise(n_samples: int = 100):
    """Download Anthropic HH-RLHF and convert to pairwise format.

    Each record holds a prompt plus two responses; ``gold_label`` marks
    which side ("A" or "B") is the human-preferred ("chosen") response.
    The chosen response is randomly placed on side A or B to avoid
    position bias.

    Args:
        n_samples: Maximum number of examples to convert.

    Writes ``data/pairwise_data.json``.  On any failure (``datasets`` not
    installed, no network, schema change) it prints a warning and returns
    so the caller can use its synthetic fallback.
    """
    print(f"[1/3] Preparing pairwise data (HH-RLHF, {n_samples} samples)...")
    try:
        from datasets import load_dataset
        ds = load_dataset("Anthropic/hh-rlhf", split="train", streaming=True)
        records = []
        for i, ex in enumerate(ds):
            if i >= n_samples:
                break
            # chosen = better response, rejected = worse
            prompt_block, resp_chosen = _split_transcript(ex.get("chosen", ""))
            _, resp_rejected = _split_transcript(ex.get("rejected", ""))

            # Randomly swap A/B so the preferred answer isn't always "A";
            # gold_label records where the chosen response ended up.
            if random.random() < 0.5:
                resp_a, resp_b, gold = resp_chosen, resp_rejected, "A"
            else:
                resp_a, resp_b, gold = resp_rejected, resp_chosen, "B"
            records.append({
                "prompt": prompt_block,
                "response_a": resp_a,
                "response_b": resp_b,
                "gold_label": gold,
                "source": "hh-rlhf",
            })

        out = DATA_DIR / "pairwise_data.json"
        # Explicit utf-8 (platform default may differ, e.g. cp1252 on Windows)
        # and ensure_ascii=False so non-ASCII text stays human-readable.
        with open(out, "w", encoding="utf-8") as f:
            json.dump(records, f, indent=2, ensure_ascii=False)
        print(f"  βœ“ Saved {len(records)} pairwise examples β†’ {out}")
    except Exception as e:
        # Deliberate best-effort: downstream code has a synthetic fallback.
        print(f"  βœ— Failed: {e} β€” synthetic fallback will be used")


def _axis_score(annots, key, default=3):
    """Extract a 1-5 rating for axis *key* from an UltraFeedback annotation dict.

    Per-axis annotations are usually dicts like ``{"Rating": "4", ...}``,
    but bare numbers occur too.  Unparseable values fall back to *default*;
    valid values are clamped to the 1-5 scale UltraFeedback uses.
    """
    val = annots.get(key, {})
    if isinstance(val, dict):
        raw = val.get("Rating", default)
    elif isinstance(val, (int, float)):
        raw = val
    else:
        raw = default
    try:
        return max(1, min(5, int(raw)))
    except Exception:
        return default


def prepare_likert(n_samples: int = 100):
    """Download UltraFeedback and convert to single-response likert format.

    Each record pairs an instruction with the first completion and gold
    1-5 scores on four axes.

    Args:
        n_samples: Maximum number of examples to convert.

    Writes ``data/likert_data.json``.  On any failure it prints a warning
    and returns so the caller can use its synthetic fallback.
    """
    print(f"[2/3] Preparing likert data (UltraFeedback, {n_samples} samples)...")
    try:
        from datasets import load_dataset
        ds = load_dataset("openbmb/UltraFeedback", split="train", streaming=True)
        records = []
        for i, ex in enumerate(ds):
            if i >= n_samples:
                break
            instr = ex.get("instruction", "")
            completions = ex.get("completions", [])
            if not completions:
                continue
            comp = completions[0]
            annots = comp.get("annotations", {})

            records.append({
                "prompt": instr,
                "response": comp.get("response", ""),
                "rubric": (
                    "Score on 4 axes (1=worst, 5=best): helpfulness, honesty, "
                    "harmlessness, instruction_following."
                ),
                "gold_scores": {
                    # BUG FIX: this previously read the "instruction_following"
                    # annotation, duplicating that axis instead of helpfulness.
                    "helpfulness": _axis_score(annots, "helpfulness"),
                    "honesty": _axis_score(annots, "honesty"),
                    # UltraFeedback has no harmlessness axis; truthfulness is
                    # used as a proxy with a lenient default of 4.
                    "harmlessness": _axis_score(annots, "truthfulness", 4),
                    "instruction_following": _axis_score(annots, "instruction_following"),
                },
                "source": "ultrafeedback",
            })

        out = DATA_DIR / "likert_data.json"
        # Explicit utf-8 + ensure_ascii=False: portable, human-readable output.
        with open(out, "w", encoding="utf-8") as f:
            json.dump(records, f, indent=2, ensure_ascii=False)
        print(f"  βœ“ Saved {len(records)} likert examples β†’ {out}")
    except Exception as e:
        # Deliberate best-effort: downstream code has a synthetic fallback.
        print(f"  βœ— Failed: {e} β€” synthetic fallback will be used")


def prepare_consistency(n_samples: int = 60):
    """Build 4-way ranking examples from Stanford SHP.

    Groups SHP rows by ``post_id`` to collect multiple responses per prompt,
    keeps posts with at least four responses, and emits records with four
    shuffled responses (A-D) plus ``gold_ranking``: the labels ordered from
    best to worst by ``score_ratio``.

    Args:
        n_samples: Maximum number of ranking examples to build.

    Writes ``data/consistency_data.json``.  On any failure it prints a
    warning and returns so the caller can use its synthetic fallback.
    """
    print(f"[3/3] Preparing consistency data (Stanford SHP, {n_samples} samples)...")
    try:
        from datasets import load_dataset
        ds = load_dataset("stanfordnlp/SHP", split="train", streaming=True)

        # Group by post_id to collect multiple responses per prompt.
        grouped: dict[str, dict] = {}
        for ex in ds:
            pid = ex.get("post_id", "")
            if pid not in grouped:
                grouped[pid] = {
                    "prompt": ex.get("history", ""),
                    "responses": [],
                }
            # NOTE(review): each SHP row carries TWO refs (human_ref_A/B) but
            # only one is kept here, discarding half the responses β€” confirm
            # this is intended before relying on 4-response coverage.
            grouped[pid]["responses"].append({
                "text": ex.get("human_ref_A", "") or ex.get("human_ref_B", ""),
                "score": ex.get("score_ratio", 1.0),
            })
            # Over-collect (3x) since many posts won't reach 4 responses.
            if len(grouped) >= n_samples * 3:
                break

        records = []
        for pid, data in grouped.items():
            resps = data["responses"]
            if len(resps) < 4:
                continue
            # Sort by score descending = gold ranking (best first).
            resps_sorted = sorted(resps[:4], key=lambda x: x["score"], reverse=True)
            labels = ["A", "B", "C", "D"]
            # Shuffle display order (not gold order).
            shuffled = resps_sorted[:]
            random.shuffle(shuffled)
            id_map = {labels[i]: shuffled[i] for i in range(4)}
            # BUG FIX: list.index() compares dicts by equality, so two
            # identical responses would both resolve to the first position and
            # corrupt the ranking.  Rank by object identity instead.
            rank_of = {id(r): rank for rank, r in enumerate(resps_sorted)}
            gold_ranking = sorted(labels, key=lambda lab: rank_of[id(id_map[lab])])

            records.append({
                "prompt": data["prompt"][:500],
                "response_a": id_map["A"]["text"][:400],
                "response_b": id_map["B"]["text"][:400],
                "response_c": id_map["C"]["text"][:400],
                "response_d": id_map["D"]["text"][:400],
                "gold_ranking": gold_ranking,
                "source": "stanford-shp",
            })
            if len(records) >= n_samples:
                break

        out = DATA_DIR / "consistency_data.json"
        # Explicit utf-8 + ensure_ascii=False: portable, human-readable output.
        with open(out, "w", encoding="utf-8") as f:
            json.dump(records, f, indent=2, ensure_ascii=False)
        print(f"  βœ“ Saved {len(records)} consistency examples β†’ {out}")
    except Exception as e:
        # Deliberate best-effort: downstream code has a synthetic fallback.
        print(f"  βœ— Failed: {e} β€” synthetic fallback will be used")


def main():
    """CLI entry point: download and convert all three task datasets."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--samples",
        type=int,
        default=100,
        help="Number of samples per task (default: 100)",
    )
    args = parser.parse_args()

    banner = "=" * 50
    print(banner)
    print("PreferenceLab Dataset Preparation")
    print(banner)

    prepare_pairwise(args.samples)
    prepare_likert(args.samples)
    # Consistency examples need 4 responses each, so build half as many.
    prepare_consistency(args.samples // 2)
    print("\nβœ“ Done. Run inference.py to test.")


# Allow direct invocation: python scripts/prepare_datasets.py [--samples N]
if __name__ == "__main__":
    main()