from unsloth import FastLanguageModel
import torch
from transformers import AutoTokenizer


max_seq_length = 4096
dtype = torch.bfloat16
load_in_4bit = True  # quantize the base weights to 4-bit (QLoRA-style) to save VRAM
model_name = '../out/pretrain-base'
output_dir = '../out/cpt-base'
|
|
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
|
|
print('Ignoring the tokenizer loaded by FastLanguageModel.from_pretrained; using AutoTokenizer.from_pretrained instead')
tokenizer = AutoTokenizer.from_pretrained('..', trust_remote_code=True, use_fast=True)
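# Caveat (assumption, not in the original script): if the replacement tokenizer's
# vocabulary differs from the checkpoint's, the embeddings must be resized to match
# before training, e.g.:
# model.resize_token_embeddings(len(tokenizer))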
|
|
print(f'{model=}')
print(f'{tokenizer=}')
|
|
model = FastLanguageModel.get_peft_model(
    model,
    r=64,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
        # For continued pretraining, also adapt the embeddings and LM head.
        "embed_tokens", "lm_head",
    ],
    lora_alpha=16,
    lora_dropout=0,
    bias='none',
    # 'unsloth' gradient checkpointing offloads activations to save VRAM.
    use_gradient_checkpointing='unsloth',
    random_state=23,
    use_rslora=True,  # rank-stabilized LoRA: scales updates by lora_alpha / sqrt(r)
    loftq_config=None,
)
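
# Optional sanity check (not in the original script): PEFT-wrapped models expose
# print_trainable_parameters(), which reports the trainable-parameter share.
# model.print_trainable_parameters()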
|
|
print(f'{model=}')
|
|
from datasets import concatenate_datasets
from cpt_base_datasets import cpt_base_datasets
from cpt_instruct_datasets import cpt_instruct_datasets
from unsloth_utils import load_text_dataset, load_chat_dataset


core_datasets = []


for dataset_config in cpt_base_datasets:
    dataset = load_text_dataset(tokenizer, **dataset_config)
    print(f'{dataset=}')
    core_datasets.append(dataset)
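
# cpt_instruct_datasets and load_chat_dataset are imported above but unused below.
# A sketch of how instruct data could be mixed into the corpus, assuming
# (hypothetically) that load_chat_dataset mirrors load_text_dataset's interface:
# for dataset_config in cpt_instruct_datasets:
#     dataset = load_chat_dataset(tokenizer, **dataset_config)
#     core_datasets.append(dataset)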
|
|
|
|
final_dataset = concatenate_datasets(core_datasets)
print(f'{final_dataset=}')
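# Optional (assumption): shuffle so the concatenated sources are interleaved
# rather than seen back-to-back, e.g. final_dataset = final_dataset.shuffle(seed=23)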
|
|
|
|
from unsloth import is_bfloat16_supported
from unsloth import UnslothTrainer, UnslothTrainingArguments
|
|
|
|
trainer = UnslothTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=final_dataset,
    dataset_text_field='text',
    max_seq_length=max_seq_length,
    dataset_num_proc=32,
    args=UnslothTrainingArguments(
        # Effective batch size: 8 * 8 = 64 sequences per optimizer step.
        per_device_train_batch_size=8,
        gradient_accumulation_steps=8,

        warmup_ratio=0.1,
        num_train_epochs=1,

        # UnslothTrainingArguments adds embedding_learning_rate, so the
        # embed_tokens / lm_head modules train 10x slower than the LoRA adapters.
        learning_rate=5e-5,
        embedding_learning_rate=5e-6,

        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=1,
        optim='adamw_8bit',
        weight_decay=0.01,
        lr_scheduler_type='cosine',
        seed=23,
        output_dir=output_dir,
        report_to='wandb',
    ),
)
|
|
trainer_stats = trainer.train()
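
# Persist the final adapters and tokenizer (assumption: saving the end state to
# output_dir is desired; intermediate Trainer checkpoints also land there).
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)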
|
|