{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Scratch notebook: TRL value-head model and `hh-rlhf` data\n",
    "\n",
    "Quick checks for a PPO setup with TRL:\n",
    "\n",
    "- load Llama-3-8B-Instruct with a value head in 8-bit;\n",
    "- inspect one decoder layer;\n",
    "- load the local `hh-rlhf` train split;\n",
    "- run two string-formatting checks.\n",
    "\n",
    "**Caveat:** the stored outputs come from more than one kernel session. Execution counts restart at 1, and the tqdm warnings show different environment roots (`/home/mxl/...` vs `/data3/mxl/...`). *Restart & Run All* has not been verified. The dataset cells do not use `model`, so they can be run without loading the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/mxl/anaconda3/envs/my_llm/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the policy model with a value head\n",
    "\n",
    "`AutoModelForCausalLMWithValueHead` wraps the base causal LM, which is reachable as `model.pretrained_model`. The warning below says no `v_head` weights were found in the checkpoint. According to that warning, this is expected when not resuming PPO training.\n",
    "\n",
    "- **TODO:** `set_seed` is imported but never called. Seed before any stochastic step.\n",
    "- **TODO:** `lora_config` is not defined anywhere in this notebook. Define it before uncommenting `peft_config=lora_config`.\n",
    "- **NOTE(review):** `load_in_8bit=True` triggers the deprecation warning shown below, which recommends `quantization_config=BitsAndBytesConfig(...)`. Before switching, confirm that the installed TRL version still detects 8-bit loading from `quantization_config`; this matters once `peft_config` is enabled.\n",
    "- **TODO:** the model path is absolute and machine-specific. Consider a `MODEL_PATH` constant in a config cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.17s/it]\n",
      "WARNING:root:A model is loaded from '/data/public/model/Meta-Llama-3-8B-Instruct/', and no v_head weight is found. This IS expected if you are not resuming PPO training.\n"
     ]
    }
   ],
   "source": [
    "model = AutoModelForCausalLMWithValueHead.from_pretrained(\n",
    "    '/data/public/model/Meta-Llama-3-8B-Instruct/',\n",
    "    load_in_8bit=True,\n",
    "    # peft_config=lora_config,\n",
    "    device_map='auto',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inspect one decoder layer\n",
    "\n",
    "The repr shows that the attention and MLP projections are `Linear8bitLt` layers, which confirms the 8-bit load."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LlamaDecoderLayer(\n",
       "  (self_attn): LlamaSdpaAttention(\n",
       "    (q_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)\n",
       "    (k_proj): Linear8bitLt(in_features=4096, out_features=1024, bias=False)\n",
       "    (v_proj): Linear8bitLt(in_features=4096, out_features=1024, bias=False)\n",
       "    (o_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)\n",
       "    (rotary_emb): LlamaRotaryEmbedding()\n",
       "  )\n",
       "  (mlp): LlamaMLP(\n",
       "    (gate_proj): Linear8bitLt(in_features=4096, out_features=14336, bias=False)\n",
       "    (up_proj): Linear8bitLt(in_features=4096, out_features=14336, bias=False)\n",
       "    (down_proj): Linear8bitLt(in_features=14336, out_features=4096, bias=False)\n",
       "    (act_fn): SiLU()\n",
       "  )\n",
       "  (input_layernorm): LlamaRMSNorm()\n",
       "  (post_attention_layernorm): LlamaRMSNorm()\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.pretrained_model.model.layers[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the `hh-rlhf` dataset (train split)\n",
    "\n",
    "The dataset is read from a local copy at `./datasets/hh-rlhf/`. The generation log shows 160,800 train and 8,552 test examples; only the train split is kept."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data3/mxl/anaconda3/envs/my_llm/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating train split: 160800 examples [00:00, 246619.83 examples/s]\n",
      "Generating test split: 8552 examples [00:00, 191860.63 examples/s]\n"
     ]
    }
   ],
   "source": [
    "ds = load_dataset('./datasets/hh-rlhf/', split='train')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Scratch: string formatting\n",
    "\n",
    "An f-string replacement field needs an expression, so an empty `{}` is a `SyntaxError`. The expression `[1, 2, 3]` matches the stored output below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'sfasdf [1, 2, 3]'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "f\"sfasdf {[1, 2, 3]}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1_2_3'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "'_'.join(map(str, [1, 2, 3]))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "RiC",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}