""" Comprehensive training launcher for step-level CoT vision model. This script orchestrates the full training pipeline: 1. Add special tokens to base model 2. Prepare datasets (FLUX-Reason-6M + FineVision) 3. Run SFT + PRM + RL training phases 4. Validate and save final model Usage: python train_qwen_cot_dual.py --config train_configs/train_qwen_cot_dual.yaml For specific phases only: python train_qwen_cot_dual.py --config CONFIG --phases sft prm """ import argparse import os import sys from pathlib import Path import yaml import torch from typing import Dict, List, Optional, Any import logging # Clone repository if running in HF Jobs environment if str(Path(__file__).parent) == "/tmp": from huggingface_hub import snapshot_download import shutil print("Running in HF Jobs - downloading repository...") # Clear any existing cache cache_dir = "/tmp/dei-model" if os.path.exists(cache_dir): print(f"Clearing existing cache at {cache_dir}") shutil.rmtree(cache_dir) repo_path = snapshot_download( repo_id="Mulebot/dei-model", repo_type="model", local_dir=cache_dir, token=os.environ.get("HF_TOKEN"), revision="main", # Force latest version force_download=True, # Force fresh download, ignore cache local_dir_use_symlinks=False, # Copy files instead of symlinks ) repo_root = Path(repo_path) print(f"Repository downloaded to: {repo_root}") # Change working directory to repository root os.chdir(repo_root) sys.path.insert(0, str(repo_root)) print(f"Changed working directory to: {os.getcwd()}") else: repo_root = Path(__file__).parent.absolute() # Add repository root and src to Python path sys.path.insert(0, str(repo_root)) sys.path.insert(0, str(repo_root / "src")) from src.reasoning.step_level_cot import StepLevelCoTTrainer from src.reasoning.flux_reason_dataset import FLUXReasonDataset, create_flux_reason_dataloader from src.reasoning.finevision_dataset import FineVisionDataset, create_finevision_dataloader from src.reasoning.step_data import ReasoningChain from utils.add_special_tokens import add_special_tokens_to_model # Set up logging logging.basicConfig( format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO ) logger = logging.getLogger(__name__) def load_config(config_path: str) -> Dict[str, Any]: """Load training configuration from YAML file.""" # If running in HF Jobs and path is relative, look in repository if not Path(config_path).is_absolute() and not Path(config_path).exists(): if repo_root and (repo_root / config_path).exists(): config_path = str(repo_root / config_path) with open(config_path, 'r') as f: config = yaml.safe_load(f) return config def prepare_model_with_tokens(config: Dict[str, Any]) -> str: """ Add special tokens to the base model and save to output directory. 


def prepare_model_with_tokens(config: Dict[str, Any]) -> str:
    """
    Add special tokens to the base model and save it to the output directory.

    Args:
        config: Training configuration

    Returns:
        Path to the model with special tokens added
    """
    logger.info("=" * 60)
    logger.info("STEP 1: Adding Special Tokens to Base Model")
    logger.info("=" * 60)

    model_path = config['pretrained_model_name_or_path']
    output_dir = Path(config['output_dir']) / 'model_with_tokens'

    if output_dir.exists() and (output_dir / 'config.json').exists():
        logger.info(f"Model with tokens already exists at {output_dir}")
        logger.info("Skipping token addition step.")
        return str(output_dir)

    logger.info(f"Base model: {model_path}")
    logger.info(f"Output directory: {output_dir}")
    logger.info("Adding 27 special tokens (15 base + 12 region tokens)...")

    try:
        add_special_tokens_to_model(
            model_path=model_path,
            output_dir=str(output_dir),
        )
        logger.info("✓ Successfully added special tokens!")
        return str(output_dir)
    except Exception as e:
        logger.error(f"✗ Failed to add special tokens: {e}")
        logger.info("Continuing with base model (tokens will be added by trainer)")
        return model_path


def prepare_datasets(config: Dict[str, Any], tokenizer: Any) -> Dict[str, Any]:
    """
    Prepare the FLUX-Reason-6M and FineVision datasets.

    Args:
        config: Training configuration
        tokenizer: Model tokenizer

    Returns:
        Dictionary with a dataloader and a sampling weight for each dataset
    """
    logger.info("=" * 60)
    logger.info("STEP 2: Preparing Datasets")
    logger.info("=" * 60)

    datasets_config = config.get('datasets', [])
    data_config = config.get('data_config', {})

    dataloaders = {}
    dataset_weights = {}

    for dataset_cfg in datasets_config:
        dataset_name = dataset_cfg['name']
        dataset_type = dataset_cfg['type']
        weight = dataset_cfg.get('weight', 1.0)

        logger.info(f"\nLoading {dataset_name} ({dataset_type})...")
        logger.info(f"  Weight: {weight}")

        try:
            if dataset_type == 'flux_reason_6m':
                dataloader = create_flux_reason_dataloader(
                    tokenizer=tokenizer,
                    batch_size=data_config.get('train_batch_size', 2),
                    num_workers=data_config.get('num_workers', 4),
                    streaming=dataset_cfg.get('streaming', True),
                    max_samples=dataset_cfg.get('max_samples'),
                    dataset_name=dataset_cfg['dataset_name'],
                    split=dataset_cfg.get('split', 'train'),
                    image_key=dataset_cfg.get('image_key', 'image'),
                    reasoning_key=dataset_cfg.get('reasoning_key', 'reasoning'),
                    question_key=dataset_cfg.get('question_key', 'question'),
                    answer_key=dataset_cfg.get('answer_key', 'answer'),
                )
                dataloaders[dataset_name] = dataloader
                dataset_weights[dataset_name] = weight
                logger.info(f"✓ Loaded {dataset_name}")
            elif dataset_type == 'finevision':
                dataloader = create_finevision_dataloader(
                    tokenizer=tokenizer,
                    batch_size=data_config.get('train_batch_size', 2),
                    num_workers=data_config.get('num_workers', 4),
                    streaming=dataset_cfg.get('streaming', True),
                    max_samples=dataset_cfg.get('max_samples'),
                    dataset_name=dataset_cfg['dataset_name'],
                    subset=dataset_cfg.get('subset'),
                    split=dataset_cfg.get('split', 'train'),
                    image_key=dataset_cfg.get('image_key', 'image'),
                    text_key=dataset_cfg.get('text_key', 'text'),
                    generate_synthetic_reasoning=dataset_cfg.get('generate_synthetic_reasoning', True),
                    num_synthetic_steps=dataset_cfg.get('num_synthetic_steps', 3),
                )
                dataloaders[dataset_name] = dataloader
                dataset_weights[dataset_name] = weight
                logger.info(f"✓ Loaded {dataset_name}")
            else:
                logger.warning(f"Unknown dataset type: {dataset_type}")
        except Exception as e:
            logger.error(f"✗ Failed to load {dataset_name}: {e}")
            logger.info("Continuing without this dataset...")

    logger.info(f"\n✓ Prepared {len(dataloaders)} datasets")
    return {'dataloaders': dataloaders, 'weights': dataset_weights}
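

# A minimal sketch of one way the per-dataset weights returned by
# prepare_datasets() could drive interleaved sampling. The actual mixing
# strategy lives inside StepLevelCoTTrainer; this helper is illustrative
# only and is not called anywhere in the pipeline.
def _example_weighted_batches(dataloaders: Dict[str, Any],
                              weights: Dict[str, float],
                              seed: int = 0):
    """Yield (dataset_name, batch) pairs, sampling datasets by weight."""
    import random

    rng = random.Random(seed)
    names = list(dataloaders.keys())
    iterators = {name: iter(dataloaders[name]) for name in names}

    while names:
        # Pick a dataset with probability proportional to its weight.
        probs = [weights.get(name, 1.0) for name in names]
        name = rng.choices(names, weights=probs, k=1)[0]
        try:
            yield name, next(iterators[name])
        except StopIteration:
            # Drop exhausted datasets and keep sampling from the rest.
            names.remove(name)
            del iterators[name]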


def run_training_phase(
    trainer: StepLevelCoTTrainer,
    phase: str,
    phase_config: Dict[str, Any],
    dataloaders: Dict[str, Any],
) -> None:
    """
    Run a specific training phase (SFT, PRM, or RL).

    Args:
        trainer: StepLevelCoTTrainer instance
        phase: Training phase name ('sft', 'prm', or 'rl')
        phase_config: Configuration for this phase
        dataloaders: Dataset dataloaders (currently unused here; the trainer
            reads the dataloaders attached to it in main())
    """
    logger.info("=" * 60)
    logger.info(f"STEP 3: Running {phase.upper()} Training Phase")
    logger.info("=" * 60)

    if not phase_config.get('enabled', False):
        logger.info(f"{phase.upper()} phase is disabled in config. Skipping...")
        return

    max_steps = phase_config.get('max_train_steps', 5000)
    learning_rate = phase_config.get('learning_rate', 1e-5)

    logger.info(f"Phase: {phase.upper()}")
    logger.info(f"Max steps: {max_steps}")
    logger.info(f"Learning rate: {learning_rate}")

    try:
        if phase == 'sft':
            logger.info("\nRunning Supervised Fine-Tuning...")
            trainer.train_sft(
                max_steps=max_steps,
                learning_rate=learning_rate,
            )
            logger.info("✓ SFT phase complete!")
        elif phase == 'prm':
            logger.info("\nTraining Process Reward Model...")
            prm_config = phase_config.get('prm_config', {})
            trainer.train_prm(
                max_steps=max_steps,
                learning_rate=learning_rate,
                **prm_config,
            )
            logger.info("✓ PRM training complete!")
        elif phase == 'rl':
            logger.info("\nRunning Reinforcement Learning with PPO...")
            ppo_config = phase_config.get('ppo_config', {})
            trainer.train_rl(
                max_steps=max_steps,
                learning_rate=learning_rate,
                **ppo_config,
            )
            logger.info("✓ RL training complete!")
        else:
            logger.warning(f"Unknown phase: {phase}")
    except Exception as e:
        logger.error(f"✗ {phase.upper()} phase failed: {e}")
        logger.error("Stopping training pipeline.")
        raise
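

# Illustrative only: driving a single phase from Python instead of the CLI,
# mirroring what main() does below. Paths and phase settings are example
# values, not the repository's actual configuration.
#
#   config = load_config("train_configs/train_qwen_cot_dual.yaml")
#   trainer = StepLevelCoTTrainer("train_configs/train_qwen_cot_dual.yaml")
#   info = prepare_datasets(config, trainer.tokenizer)
#   trainer.dataloaders = info['dataloaders']
#   trainer.dataset_weights = info['weights']
#   run_training_phase(
#       trainer=trainer,
#       phase='sft',
#       phase_config={'enabled': True, 'max_train_steps': 100, 'learning_rate': 1e-5},
#       dataloaders=info['dataloaders'],
#   )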


def main():
    parser = argparse.ArgumentParser(description="Train the step-level CoT vision model")
    parser.add_argument(
        '--config', type=str, required=True,
        help='Path to the training configuration YAML file'
    )
    parser.add_argument(
        '--phases', nargs='+',
        choices=['sft', 'prm', 'rl', 'all'],
        default=['all'],
        help='Training phases to run (default: all)'
    )
    parser.add_argument(
        '--skip_token_addition', action='store_true',
        help='Skip the special token addition step (if already done)'
    )
    parser.add_argument(
        '--resume', type=str, default=None,
        help='Resume from a checkpoint directory'
    )
    args = parser.parse_args()
    # NOTE: args.resume is parsed but not consumed in this launcher; any
    # checkpoint resumption is assumed to happen inside StepLevelCoTTrainer.

    # Adjust the config path if running in the HF Jobs environment
    config_path = args.config
    if not Path(config_path).is_absolute() and not Path(config_path).exists():
        if repo_root and (repo_root / config_path).exists():
            config_path = str(repo_root / config_path)

    # Load configuration
    logger.info("=" * 60)
    logger.info("QWEN STEP-LEVEL COT TRAINING PIPELINE")
    logger.info("=" * 60)
    logger.info(f"Config: {config_path}")
    logger.info(f"Phases: {args.phases}")

    config = load_config(config_path)

    # Determine which phases to run
    if 'all' in args.phases:
        phases_to_run = ['sft', 'prm', 'rl']
    else:
        phases_to_run = args.phases

    # Step 1: Add special tokens (optional, can be skipped)
    if args.skip_token_addition:
        logger.info("\nSkipping special token addition (--skip_token_addition)")
        model_path = config['pretrained_model_name_or_path']
    else:
        model_path = prepare_model_with_tokens(config)

    # Update the config with the tokenized model path. The trainer is
    # constructed from a YAML path and re-reads it from disk, so persist the
    # updated config to a resolved file instead of only mutating the
    # in-memory dict (which the trainer would never see).
    config['pretrained_model_name_or_path'] = model_path
    resolved_config_path = Path(config['output_dir']) / 'train_config_resolved.yaml'
    resolved_config_path.parent.mkdir(parents=True, exist_ok=True)
    with open(resolved_config_path, 'w') as f:
        yaml.safe_dump(config, f)

    # Step 2: Initialize the trainer
    logger.info("\n" + "=" * 60)
    logger.info("INITIALIZING TRAINER")
    logger.info("=" * 60)

    try:
        trainer = StepLevelCoTTrainer(str(resolved_config_path))
        logger.info("✓ Trainer initialized successfully!")
    except Exception as e:
        logger.error(f"✗ Failed to initialize trainer: {e}")
        return 1

    # Step 3: Prepare the datasets
    dataset_info = prepare_datasets(config, trainer.tokenizer)

    # Attach the datasets to the trainer
    trainer.dataloaders = dataset_info['dataloaders']
    trainer.dataset_weights = dataset_info['weights']

    # Step 4: Run the training phases
    training_phases = config.get('training_phases', [])
    for phase_info in training_phases:
        phase_name = phase_info['phase']
        if phase_name not in phases_to_run:
            logger.info(f"\nSkipping {phase_name.upper()} phase (not in --phases)")
            continue

        run_training_phase(
            trainer=trainer,
            phase=phase_name,
            phase_config=phase_info,
            dataloaders=dataset_info['dataloaders'],
        )

    # Step 5: Save the final model
    logger.info("\n" + "=" * 60)
    logger.info("STEP 4: Saving Final Model")
    logger.info("=" * 60)

    final_output_dir = Path(config['output_dir']) / 'final_model'
    final_output_dir.mkdir(parents=True, exist_ok=True)

    try:
        trainer.save_checkpoint(str(final_output_dir))
        logger.info(f"✓ Final model saved to {final_output_dir}")
    except Exception as e:
        logger.error(f"✗ Failed to save final model: {e}")

    # Summary
    logger.info("\n" + "=" * 60)
    logger.info("TRAINING PIPELINE COMPLETE!")
    logger.info("=" * 60)
    logger.info(f"Model: {config['pretrained_model_name_or_path']}")
    logger.info(f"Phases completed: {phases_to_run}")
    logger.info(f"Output directory: {config['output_dir']}")
    logger.info(f"Final model: {final_output_dir}")
    logger.info("\nNext steps:")
    logger.info("  1. Test inference: python inference_document_processor.py")
    logger.info("  2. Test generation: python inference_image_generator.py")
    logger.info("  3. Evaluate on benchmarks: python evaluate_reasoning.py")

    return 0


if __name__ == '__main__':
    exit_code = main()
    sys.exit(exit_code)