"""
Initialize a LoopLlama model from Llama 3.2-1B weights.
"""
| |
|
| | import os |
| | import torch |
| | from transformers import LlamaForCausalLM, LlamaConfig, LlamaTokenizer |
| | from configuration_llama import LlamaConfig |
| | from modeling_llama import LoopLlamaForCausalLM |
| |
|
def setup_loopllama_from_pretrained(
    source_model_path="meta-llama/Llama-3.2-1B",
    target_path="./",
    loop_times=2
):
    """Create a LoopLlama model initialized from a pretrained Llama checkpoint.

    Loads the source Llama weights, builds a LoopLlama config/model with the
    requested loop count, copies every matching tensor from the source state
    dict, saves model + config + tokenizer to ``target_path``, and reloads the
    saved model once as a sanity check.

    Args:
        source_model_path: Hub ID or local path of the source Llama model.
        target_path: Directory the converted model is saved to.
        loop_times: Number of loop iterations stored in the LoopLlama config.

    Returns:
        Tuple of ``(loop_model, tokenizer)``.
    """
    print(f"Loading original Llama model from {source_model_path}...")

    original_model = LlamaForCausalLM.from_pretrained(
        source_model_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    )
    original_config = original_model.config

    # Llama 3.x checkpoints ship a tokenizer the legacy LlamaTokenizer class
    # may not be able to load; fall back to AutoTokenizer in that case.
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit escape.
    try:
        tokenizer = LlamaTokenizer.from_pretrained(source_model_path)
    except Exception:
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(source_model_path)

    print("Creating LoopLlama configuration...")

    # Merge loop_times into the dict before construction: passing it as an
    # extra keyword alongside **to_dict() raises TypeError if the source
    # config ever already carries a `loop_times` entry.
    config_kwargs = original_config.to_dict()
    config_kwargs["loop_times"] = loop_times
    loop_config = LlamaConfig(**config_kwargs)

    print(f"Creating LoopLlama model with {loop_times} loop times...")

    loop_model = LoopLlamaForCausalLM(loop_config)

    print("Copying weights from original model...")

    original_state_dict = original_model.state_dict()
    loop_state_dict = loop_model.state_dict()

    # copy_() is an in-place op on leaf parameters that require grad, so it
    # must run under no_grad() to avoid an autograd RuntimeError.
    with torch.no_grad():
        for key in loop_state_dict.keys():
            if key in original_state_dict:
                print(f"Copying {key}")
                loop_state_dict[key].copy_(original_state_dict[key])
            else:
                print(f"Warning: {key} not found in original model")

    print(f"Saving LoopLlama model to {target_path}...")

    # Make sure the output directory exists before the save_pretrained calls.
    os.makedirs(target_path, exist_ok=True)
    loop_model.save_pretrained(target_path)
    loop_config.save_pretrained(target_path)
    tokenizer.save_pretrained(target_path)

    print("Setup completed!")

    # Round-trip check: reload the just-saved model to confirm the custom
    # config (including loop_times) deserializes correctly.
    print("Verifying model loading...")
    test_model = LoopLlamaForCausalLM.from_pretrained(
        target_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16
    )
    print(f"Model loaded successfully. Loop times: {test_model.config.loop_times}")

    return loop_model, tokenizer
| |
|
if __name__ == "__main__":
    import argparse

    # CLI entry point: convert the source checkpoint into a LoopLlama model.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--source",
        default="/9950backfile/zjy_2/loopllama_cpt/loopllama-cpt/models/llama3_2-1B",
        help="Source Llama model",
    )
    cli.add_argument(
        "--target",
        default="./",
        help="Target directory",
    )
    cli.add_argument(
        "--loop_times",
        type=int,
        default=3,
        help="Number of loop times",
    )
    cli_args = cli.parse_args()

    setup_loopllama_from_pretrained(
        source_model_path=cli_args.source,
        target_path=cli_args.target,
        loop_times=cli_args.loop_times,
    )