#!/usr/bin/env bash
#
# DeepSpeed launcher for supervised fine-tuning (SFT).
#
# SECURITY: a Weights & Biases API key used to be hardcoded and exported
# here. Never commit credentials: anyone with read access to the repo (or
# to the `set -x` trace emitted below) would obtain the key. Supply it via
# the environment instead, e.g.:
#
#   WANDB_API_KEY=... bash "$0" [options]
#
# When WANDB_API_KEY is absent, the script switches W&B to offline mode
# further down, so runs without a key still work.
|
|
# Refuse to run under a non-bash interpreter (e.g. `sh script.sh`): the rest
# of the script relies on bash-only features ([[ ]], &>, process substitution).
# Deliberately uses POSIX `[` / printf, since this line may execute under sh.
if [ -z "${BASH_VERSION}" ]; then
  printf '%s\n' "Please use bash to run this script." >&2
  exit 1
fi
|
|
# Echo every command before executing it — useful when debugging launches.
# NOTE: this also traces environment values used in commands, so secrets
# must not be hardcoded above.
set -x

# Absolute directory containing this script, and its parent, which is
# assumed to be the project root (this script lives one level below it).
SCRIPT_DIR="$(cd "$(dirname "$0")" &>/dev/null && pwd)"
ROOT_DIR="$(dirname "${SCRIPT_DIR}")"
# Prepend the project root so the `safe_rlhf` package is importable; any
# pre-existing PYTHONPATH entries are kept after it (the :+ expansion only
# appends ":${PYTHONPATH}" when PYTHONPATH is already set and non-empty).
export PYTHONPATH="${ROOT_DIR}${PYTHONPATH:+:${PYTHONPATH}}"
# Default Python-side log verbosity unless the caller overrides it.
export LOGLEVEL="${LOGLEVEL:-WARNING}"
|
|
# Defaults; each may be overridden on the command line below.
MODEL_NAME_OR_PATH="huggyllama/llama-7b"
OUTPUT_DIR="${ROOT_DIR}/output/sft"
ZERO_STAGE=3

# Accept both `--opt value` and `--opt=value` spellings for every option.
# Unknown options abort the launch rather than being silently ignored.
while [[ "$#" -gt 0 ]]; do
  arg="$1"
  shift
  case "${arg}" in
    --model_name_or_path=*) MODEL_NAME_OR_PATH="${arg#*=}" ;;
    --model_name_or_path)
      MODEL_NAME_OR_PATH="$1"
      shift
      ;;
    --output_dir=*) OUTPUT_DIR="${arg#*=}" ;;
    --output_dir)
      OUTPUT_DIR="$1"
      shift
      ;;
    --zero_stage=*) ZERO_STAGE="${arg#*=}" ;;
    --zero_stage)
      ZERO_STAGE="$1"
      shift
      ;;
    *)
      echo "Unknown parameter passed: '${arg}'" >&2
      exit 1
      ;;
  esac
done
|
|
# Create the output directory, canonicalize it to an absolute path, and keep
# its contents out of version control by dropping a catch-all .gitignore
# (only on first use, so a user-edited .gitignore is never clobbered).
mkdir -p "${OUTPUT_DIR}"
OUTPUT_DIR="$(cd "${OUTPUT_DIR}" &>/dev/null && pwd)"
[[ -f "${OUTPUT_DIR}/.gitignore" ]] || echo '*' >"${OUTPUT_DIR}/.gitignore"

# Snapshot this launch script next to the run outputs for reproducibility.
cp -f "$0" "${OUTPUT_DIR}/script.sh"
|
|
# No API key means W&B cannot sync to the server — log offline instead.
[[ -n "${WANDB_API_KEY}" ]] || export WANDB_MODE="offline"
|
|
# Pick a free TCP port for the distributed-training master process:
# enumerate the candidate range, subtract the ports currently in use
# (listening/established sockets reported by `ss`), and pick one of the
# remaining ports at random. Linux-only: relies on ss(8) and shuf(1).
MASTER_PORT_START=10000
MASTER_PORT_END=65535
# NOTE: comm(1) requires both inputs sorted in the SAME collation order;
# here both streams use plain lexicographic `sort`, so the subtraction is
# consistent even though the values are numbers.
MASTER_PORT="$(
	comm -23 \
		<(seq "${MASTER_PORT_START}" "${MASTER_PORT_END}" | sort) \
		<(ss -Htan | awk '{ print $4 }' | awk -F ':' '{ print $NF }' | sort -u) |
		shuf | head -n 1
)"
|
|
| exec 1> >(tee "${OUTPUT_DIR}/stdout.log" >&1) 2> >(tee "${OUTPUT_DIR}/stderr.log" >&2) |
|
|
# Launch supervised fine-tuning on a single node with 4 GPUs via the
# DeepSpeed launcher, running the `safe_rlhf.finetune` module.
# NOTE(review): all flags after `--module` are forwarded to that module;
# their exact semantics live outside this file — presumably standard HF-style
# training arguments (batch sizes, LR schedule, etc.) — confirm against
# safe_rlhf.finetune's argument parser.
deepspeed --num_nodes=1 --num_gpus=4 \
	--master_port "${MASTER_PORT}" \
	--module safe_rlhf.finetune \
	--train_datasets alpaca \
	--model_name_or_path "${MODEL_NAME_OR_PATH}" \
	--max_length 512 \
	--trust_remote_code True \
	--epochs 3 \
	--per_device_train_batch_size 4 \
	--per_device_eval_batch_size 4 \
	--gradient_accumulation_steps 16 \
	--gradient_checkpointing \
	--learning_rate 2e-5 \
	--lr_scheduler_type cosine \
	--lr_warmup_ratio 0.03 \
	--weight_decay 0.0 \
	--seed 42 \
	--output_dir "${OUTPUT_DIR}" \
	--log_type wandb \
	--log_project Safe-RLHF-SFT \
	--zero_stage "${ZERO_STAGE}" \
	--bf16 True \
	--tf32 True
|
|