"""Fine-tune Mistral-7B-Instruct-v0.2 on a pre-tokenized dataset.

Loads a dataset from ``tokenized_dataset`` with ``datasets.load_from_disk``,
runs one epoch of ``transformers.Trainer`` training (checkpointing to
``./checkpoints`` once per epoch), then writes the fine-tuned weights and the
base model's tokenizer to ``./my_ai_assistant``.
"""

import torch
from datasets import load_from_disk
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
DATASET_DIR = "tokenized_dataset"
CHECKPOINT_DIR = "./checkpoints"
FINAL_MODEL_DIR = "./my_ai_assistant"


def _native_bf16_gpu() -> bool:
    """Return True when a CUDA GPU with native bf16 support (sm_80+) is present."""
    # torch.cuda.is_bf16_supported() may also report emulated support on older
    # GPUs, so check the compute capability directly (Ampere and newer).
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0)


def main() -> None:
    """Train for one epoch and save the model and tokenizer to FINAL_MODEL_DIR."""
    # NOTE(review): assumes the saved object is a single ``Dataset`` (not a
    # ``DatasetDict``) whose rows already carry ``input_ids`` and ``labels``
    # — TODO confirm. No data collator is passed, so Trainer uses its default
    # one: without a ``labels`` column the model returns no loss, and rows of
    # different lengths are not padded. If needed, pass
    # ``DataCollatorForLanguageModeling(tokenizer, mlm=False)``.
    tokenized_dataset = load_from_disk(DATASET_DIR)

    # NOTE(review): this is full fine-tuning of a 7B-parameter model with the
    # default optimizer; expect very high GPU memory use. Consider LoRA/PEFT
    # or gradient checkpointing if this runs out of memory.
    model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
    # Loaded only so it can be saved next to the fine-tuned weights, making
    # FINAL_MODEL_DIR loadable with AutoTokenizer as well.
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

    # Prefer bf16 where the GPU supports it natively: fp16's narrow exponent
    # range makes loss overflow/NaNs more likely. Fall back to fp16 on other
    # CUDA GPUs, and to full precision without CUDA (transformers rejects
    # fp16 mixed precision there).
    use_bf16 = _native_bf16_gpu()
    use_fp16 = torch.cuda.is_available() and not use_bf16

    # Evaluation is left at its default ("no"); the former
    # ``evaluation_strategy`` keyword is deprecated in favour of
    # ``eval_strategy`` and has been removed in newer releases.
    training_args = TrainingArguments(
        output_dir=CHECKPOINT_DIR,
        num_train_epochs=1,
        per_device_train_batch_size=1,
        # Effective batch of 8 per device while holding one sample at a time.
        gradient_accumulation_steps=8,
        save_strategy="epoch",
        bf16=use_bf16,
        fp16=use_fp16,
        logging_steps=50,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
    )
    trainer.train()

    # Trainer.save_model (unlike calling model.save_pretrained on every rank)
    # handles distributed / DeepSpeed / FSDP setups and writes only from the
    # main process. It serializes to safetensors by default in recent
    # transformers releases (TrainingArguments.save_safetensors).
    trainer.save_model(FINAL_MODEL_DIR)
    if trainer.is_world_process_zero():
        tokenizer.save_pretrained(FINAL_MODEL_DIR)


if __name__ == "__main__":
    main()
|