---
language:
- en
pipeline_tag: text-generation
tags:
- pytorch
- Mistral
---

## Model Details

We employ **Mistral-Base(7B)** as one of the base models to evaluate our proposed **Reward-Driven Selective Penalization for Preference Alignment Optimization (RSPO)** method. The model is trained for **one epoch** on the **UltraFeedback Binarized** dataset using RSPO.

## How to use

#### Transformers AutoModelForCausalLM

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "li11111/Mistral-7B-Base-RSPO"

# Load the tokenizer and the model in bfloat16, sharded across the available devices
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

# Apply the model's chat template and move the prompt to the model's device
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)

# Stop generation at the end-of-sequence token
terminators = [
    tokenizer.eos_token_id
]

outputs = model.generate(
    input_ids,
    max_new_tokens=256,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)

# Decode only the newly generated tokens, excluding the prompt
response = outputs[0][input_ids.shape[-1]:]
print(tokenizer.decode(response, skip_special_tokens=True))
```

## Experiment Parameters

| | **Parameter** | **Mistral-Base(7B)** | |
| | ------------------- | -------------------- | |
| | `GPU` | 8×Ascend910B | |
| | `beta` | 0.01 | |
| | `batch` | 128 | |
| | `learning_rate` | 5e-7 | |
| | `max_prompt_length` | 512 | |
| | `max_length` | 1024 | |
| | `num_train_epochs` | 1 | |
| | `torch_dtype` | `bfloat16` | |
| | `warmup_ratio` | 0.1 | |
| | `β_w` | 0.01 | |
| | `β_l` | 0.1 | |
| | `λ` | 0.1 | |

## Training Data

We use the [HuggingFaceH4/ultrafeedback_binarized](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized) dataset to train the Mistral Base model.
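
To inspect the preference pairs before training, the dataset can be loaded directly. The `train_prefs` split and the `prompt`/`chosen`/`rejected` columns below follow the dataset card and may change upstream.

```python
from datasets import load_dataset

# Preference split of UltraFeedback Binarized (split and column names as
# published on the dataset card).
ds = load_dataset("HuggingFaceH4/ultrafeedback_binarized", split="train_prefs")

example = ds[0]
print(example["prompt"])                   # instruction shown to the model
print(example["chosen"][-1]["content"])    # preferred assistant response
print(example["rejected"][-1]["content"])  # dispreferred assistant response
```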

## Benchmarks

LC and WR denote the length-controlled and raw win rates (%) on AlpacaEval 2.0; Avg. Len is the average response length.

<table>
  <tr>
    <th>Method</th>
    <th colspan="3" style="text-align: center;">AlpacaEval 2.0</th>
  </tr>
  <tr>
    <th></th>
    <th>LC</th>
    <th>WR</th>
    <th>Avg. Len</th>
  </tr>
  <tr>
    <td><b>RSPO</b></td>
    <td><b>25.4</b></td>
    <td><b>23.7</b></td>
    <td>1873</td>
  </tr>
</table>

| | **Method** | **GSM8K** | **ARC** | **TQA** | **MMLU** | **IFEval** | **Avg.** | |
| | ---------- | --------- | --------- | --------- | --------- | ---------- | --------- | |
| | **SFT** | **42.61** | 55.97 | 28.15 | 57.17 | 36.59 | 44.10 | |
| | **DPO** | 33.13 | 59.64 | 46.14 | 57.46 | 50.48 | 49.37 | |
| | **R-DPO** | 30.10 | 56.06 | 40.64 | 58.48 | 53.24 | 47.70 | |
| | **SimPO** | 33.59 | **60.15** | 43.45 | 58.25 | 52.98 | 49.68 | |
| | **WPO** | 30.63 | 57.00 | 40.51 | 58.54 | **55.64** | 48.46 | |
| | **RSPO** | 37.45 | 57.94 | **47.25** | **58.58** | 55.04 | **51.25** | |
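
The card does not state which evaluation harness produced these numbers. Below is a minimal sketch of how comparable scores are commonly collected with EleutherAI's lm-evaluation-harness; the task names (e.g. `arc_challenge` for ARC, `truthfulqa_mc2` for TQA) are assumptions about how the abbreviated columns map to harness tasks.

```python
# Sketch only: assumes the lm-evaluation-harness (`pip install lm-eval`) and an
# assumed mapping from the column abbreviations to harness task names.
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=li11111/Mistral-7B-Base-RSPO,dtype=bfloat16",
    tasks=["gsm8k", "arc_challenge", "truthfulqa_mc2", "mmlu", "ifeval"],
    batch_size=8,
)
print(results["results"])
```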