# precis/scripts/train.py
#!/usr/bin/env python3
"""CLI training script for Précis."""
import argparse
import logging
import sys
from pathlib import Path
# Add project root to path
# (lets the script run directly from a source checkout without installing
# the package; presumably the layout is <root>/scripts/train.py, so
# parent.parent is the repo root — TODO confirm against the repo layout)
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.config import ModelConfig, TrainingConfig, DataConfig
from src.model import load_model, load_tokenizer, prepare_for_training
from src.tuning.lora import apply_lora
from src.tuning.data import create_dummy_data, prepare_dataset
from src.tuning.trainer import PrecisTrainer
# Root logger: INFO level, timestamped "<time> - <module> - <level> - <msg>" lines.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
def parse_args(argv=None):
    """Parse command-line options for the training script.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case argparse falls back to ``sys.argv[1:]`` — identical
            to the previous behavior. Passing an explicit list makes the
            parser usable from tests without patching ``sys.argv``.

    Returns:
        argparse.Namespace with the parsed options (``model_id``,
        ``output_dir``, ``epochs``, ``batch_size``, ``learning_rate``,
        ``lora_r``, ``dry_run``, ``dummy_samples``).
    """
    parser = argparse.ArgumentParser(description="Train Précis summarization model")
    parser.add_argument("--model-id", type=str, default=None, help="HuggingFace model ID")
    parser.add_argument("--output-dir", type=str, default="./outputs", help="Output directory")
    parser.add_argument("--epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=4, help="Batch size")
    parser.add_argument("--learning-rate", type=float, default=2e-4, help="Learning rate")
    parser.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
    parser.add_argument("--dry-run", action="store_true", help="Validate pipeline without training")
    parser.add_argument("--dummy-samples", type=int, default=100, help="Number of dummy samples")
    return parser.parse_args(argv)
def _report_dry_run(model_config, training_config):
    """Log the resolved configuration and exercise only the data pipeline."""
    logger.info("=== DRY RUN MODE ===")
    logger.info(f"Model: {model_config.model_id}")
    logger.info(f"Output: {training_config.output_dir}")
    logger.info(f"Epochs: {training_config.num_epochs}, Batch: {training_config.batch_size}")
    logger.info(f"LoRA r: {training_config.lora_r}, alpha: {training_config.lora_alpha}")
    # Test data pipeline only
    dummy_data = create_dummy_data(5)
    logger.info(f"Dummy data sample: {dummy_data[0]}")
    logger.info("Dry run complete. Pipeline validated.")


def main():
    """Entry point: build configs from CLI options, then train (or dry-run)."""
    opts = parse_args()

    # Assemble configuration objects from the parsed options.
    model_config = ModelConfig()
    if opts.model_id:
        # An explicit CLI value overrides the config default.
        model_config.model_id = opts.model_id
    training_config = TrainingConfig(
        output_dir=opts.output_dir,
        num_epochs=opts.epochs,
        batch_size=opts.batch_size,
        learning_rate=opts.learning_rate,
        lora_r=opts.lora_r,
    )
    data_config = DataConfig()

    if opts.dry_run:
        # Validate the pipeline without loading the model or training.
        _report_dry_run(model_config, training_config)
        return

    # Load model and tokenizer, then attach LoRA adapters for fine-tuning.
    logger.info("Loading model and tokenizer...")
    tokenizer = load_tokenizer(model_config)
    net = prepare_for_training(load_model(model_config))
    net = apply_lora(net, training_config)

    # Build the (dummy) training dataset.
    logger.info("Preparing training data...")
    dataset = prepare_dataset(create_dummy_data(opts.dummy_samples), tokenizer, data_config)

    # Run training and persist the result.
    trainer = PrecisTrainer(
        model=net,
        tokenizer=tokenizer,
        train_dataset=dataset,
        config=training_config,
    )
    trainer.train()
    trainer.save()
    logger.info("Training complete!")
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()