""" Step-Level Chain of Thought Trainer for Vision-Language Models. Integrates PRM, RL training, and inference-time scaling for fine-grained reasoning. """ import argparse import logging from pathlib import Path from typing import Dict, Optional, Any import torch import yaml from accelerate import Accelerator from accelerate.logging import get_logger from transformers import AutoTokenizer from diffusers import QwenImagePipeline, QwenImageTransformer2DModel from omegaconf import OmegaConf from tqdm.auto import tqdm from .step_data import StepDataset, ReasoningChain, SPECIAL_TOKENS from .prm import ProcessRewardModel, PRMTrainer from .rl_trainer import RLReasoningTrainer, RLConfig from .inference_scaling import InferenceTimeScaling, InferenceConfig logger = get_logger(__name__, log_level="INFO") class StepLevelCoTTrainer: """ Complete training pipeline for step-level chain-of-thought reasoning. Training phases: 1. Supervised Fine-tuning (SFT) on reasoning chains 2. PRM training on step-level rewards 3. RL fine-tuning with PRM rewards 4. Inference-time scaling evaluation """ def __init__( self, config_path: str, device: str = "cuda", ): """ Initialize trainer. Args: config_path: Path to training configuration YAML device: Device for training """ self.device = device # Load configuration with open(config_path, 'r') as f: config_dict = yaml.safe_load(f) self.config = OmegaConf.create(config_dict) # Initialize accelerator self.accelerator = Accelerator( gradient_accumulation_steps=self.config.gradient_accumulation_steps, mixed_precision=self.config.mixed_precision, log_with=self.config.report_to, project_dir=self.config.output_dir, ) # Load model and tokenizer self._load_model() # Initialize PRM self.prm = ProcessRewardModel( vision_dim=self.config.vision_dim, text_dim=self.config.text_dim, hidden_dim=self.config.prm_hidden_dim, num_heads=self.config.prm_num_heads, dropout=self.config.prm_dropout, max_steps=self.config.max_reasoning_steps, ) # Initialize RL trainer with explicit type conversion rl_config = RLConfig( clip_epsilon=float(self.config.ppo_clip_epsilon), value_loss_coef=float(self.config.value_loss_coef), entropy_coef=float(self.config.entropy_coef), gamma=float(self.config.gamma), gae_lambda=float(self.config.gae_lambda), learning_rate=float(self.config.rl_learning_rate), batch_size=int(self.config.train_batch_size), use_prm_rewards=True, normalize_rewards=True, ) self.rl_trainer = RLReasoningTrainer( policy_model=self.model, prm_model=self.prm, config=rl_config, device=device, ) # Initialize inference scaler # Ensure numeric types are properly converted try: num_samples = int(self.config.num_inference_samples) temperature = float(self.config.inference_temperature) aggregation = str(self.config.aggregation_method) inference_config = InferenceConfig( num_samples=num_samples, temperature=temperature, aggregation=aggregation, use_prm_scores=True, ) self.inference_scaler = InferenceTimeScaling( model=self.model, prm=self.prm, config=inference_config, device=device, ) except Exception as e: logger.error(f"Failed to initialize inference scaler: {e}") logger.error(f"Config values - num_samples: {self.config.num_inference_samples} (type: {type(self.config.num_inference_samples)})") logger.error(f"Config values - temperature: {self.config.inference_temperature} (type: {type(self.config.inference_temperature)})") logger.error(f"Config values - aggregation: {self.config.aggregation_method} (type: {type(self.config.aggregation_method)})") raise logger.info("StepLevelCoTTrainer initialized") def _load_model(self) -> None: """Load vision-language model (supports both VLM and diffusion models).""" logger.info(f"Loading model: {self.config.pretrained_model_name_or_path}") model_name = self.config.pretrained_model_name_or_path.lower() # Detect model type is_vlm = any(x in model_name for x in ["qwen2", "qwen-vl", "internvl", "llava", "cogvlm"]) is_diffusion = any(x in model_name for x in ["flux", "qwen-image", "stable-diffusion", "sd"]) if is_vlm: # Load as Vision-Language Model from transformers import AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig logger.info("Loading as Vision-Language Model...") # Prepare quantization config if enabled quantization_config = None if self.config.get("quantize", False): quantization_config = BitsAndBytesConfig( load_in_8bit=True, llm_int8_threshold=6.0, llm_int8_has_fp16_weight=False, ) # Load processor and model self.tokenizer = AutoProcessor.from_pretrained( self.config.pretrained_model_name_or_path, trust_remote_code=True ) torch_dtype = torch.float16 if self.config.mixed_precision == "fp16" else torch.bfloat16 self.model = AutoModelForVision2Seq.from_pretrained( self.config.pretrained_model_name_or_path, torch_dtype=torch_dtype, quantization_config=quantization_config, trust_remote_code=True, device_map="auto" ) # For VLMs, model itself is the main component self.pipeline = None self.vae = None self.text_encoder = None elif is_diffusion: # Load as Diffusion Model logger.info("Loading as Diffusion Model...") self.pipeline = QwenImagePipeline.from_pretrained( self.config.pretrained_model_name_or_path, torch_dtype=torch.float16 if self.config.mixed_precision == "fp16" else torch.float32, ) self.model = self.pipeline.transformer self.vae = self.pipeline.vae self.text_encoder = self.pipeline.text_encoder self.tokenizer = self.pipeline.tokenizer else: # Try loading as VLM by default logger.warning(f"Unknown model type, attempting to load as VLM...") from transformers import AutoProcessor, AutoModelForVision2Seq self.tokenizer = AutoProcessor.from_pretrained( self.config.pretrained_model_name_or_path, trust_remote_code=True ) self.model = AutoModelForVision2Seq.from_pretrained( self.config.pretrained_model_name_or_path, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto" ) self.pipeline = None self.vae = None self.text_encoder = None # Add special tokens for chain-of-thought reasoning self._add_special_tokens() # Enable gradient checkpointing if available if self.config.gradient_checkpointing: if hasattr(self.model, 'enable_gradient_checkpointing'): self.model.enable_gradient_checkpointing() elif hasattr(self.model, 'gradient_checkpointing_enable'): self.model.gradient_checkpointing_enable() else: logger.warning("Gradient checkpointing requested but not available for this model") logger.info("Model loaded successfully") def _add_special_tokens(self) -> None: """Add special tokens for chain-of-thought reasoning.""" from .step_data import SPECIAL_TOKENS, SPECIAL_TOKEN_DESCRIPTIONS logger.info(f"Adding {len(SPECIAL_TOKENS)} special tokens for chain-of-thought reasoning...") # Get actual tokenizer (handle both Tokenizer and Processor objects) if hasattr(self.tokenizer, 'tokenizer'): # This is a processor, extract the tokenizer actual_tokenizer = self.tokenizer.tokenizer else: # This is already a tokenizer actual_tokenizer = self.tokenizer # Add special tokens to tokenizer num_added_toks = actual_tokenizer.add_special_tokens( {"additional_special_tokens": SPECIAL_TOKENS} ) if num_added_toks > 0: logger.info(f"Added {num_added_toks} new tokens. Resizing model embeddings...") # Get current embedding size old_size = self.model.get_input_embeddings().weight.size(0) # Resize token embeddings (use actual tokenizer's vocab size) new_size = len(actual_tokenizer) self.model.resize_token_embeddings(new_size) # Initialize new token embeddings with semantic initialization self._initialize_new_token_embeddings(num_added_toks, SPECIAL_TOKEN_DESCRIPTIONS) logger.info(f"Resized embeddings from {old_size} to {new_size}") else: logger.info("Special tokens already present in tokenizer.") def _initialize_new_token_embeddings(self, num_added_toks: int, descriptions: list) -> None: """Initialize newly added token embeddings with semantic information.""" # Get actual tokenizer (handle both Tokenizer and Processor objects) if hasattr(self.tokenizer, 'tokenizer'): actual_tokenizer = self.tokenizer.tokenizer else: actual_tokenizer = self.tokenizer with torch.no_grad(): input_embeddings = self.model.get_input_embeddings() # Check if model has language model head has_lm_head = hasattr(self.model, 'lm_head') if has_lm_head: lm_head = self.model.lm_head for i, description in enumerate(reversed(descriptions), start=1): # Tokenize the description using actual tokenizer tokenized = actual_tokenizer.tokenize(description) if not tokenized: continue tokenized_ids = actual_tokenizer.convert_tokens_to_ids(tokenized) # Average embeddings of description tokens new_embedding = input_embeddings.weight[tokenized_ids].mean(dim=0) input_embeddings.weight[-i, :] = new_embedding.clone().detach() # Initialize lm_head if it exists if has_lm_head: new_lm_head = lm_head.weight[tokenized_ids].mean(dim=0) lm_head.weight[-i, :] = new_lm_head.clone().detach() logger.info(f"Initialized {num_added_toks} new token embeddings with semantic information.") def load_step_dataset( self, data_dir: str, split: str = "train", ) -> StepDataset: """ Load step-level reasoning dataset. Args: data_dir: Directory containing reasoning chain JSON files split: Dataset split (train/val/test) Returns: StepDataset instance """ data_path = Path(data_dir) / split logger.info(f"Loading dataset from {data_path}") dataset = StepDataset.from_json_dir( str(data_path), tokenizer=self.tokenizer, max_steps=self.config.max_reasoning_steps, include_visual_features=True, filter_incorrect=False, # Include all chains for learning ) stats = dataset.get_statistics() logger.info(f"Dataset loaded: {stats}") return dataset def train_sft( self, max_steps: int = 10000, learning_rate: float = 2e-5, save_steps: int = 1000, eval_steps: int = 500, ) -> None: """ Supervised Fine-Tuning phase for the base model. Args: max_steps: Maximum training steps learning_rate: Learning rate for SFT save_steps: Save checkpoint every N steps eval_steps: Evaluate every N steps """ logger.info("Starting Supervised Fine-Tuning (SFT)") logger.info(f"Max steps: {max_steps}, LR: {learning_rate}") # Import dataset loaders from .flux_reason_dataset import FLUXReasonDataset from .finevision_dataset import FineVisionDataset from torch.utils.data import DataLoader # Get batch size with default batch_size = getattr(self.config, 'train_batch_size', 1) # Load FLUX-Reason-6M dataset (70% weight) logger.info("Loading FLUX-Reason-6M dataset...") flux_dataset = FLUXReasonDataset( dataset_name="LucasFang/FLUX-Reason-6M", split="train", streaming=True, max_samples=int(max_steps * batch_size * 0.7), ) # Load FineVision dataset (30% weight) logger.info("Loading FineVision dataset...") finevision_dataset = FineVisionDataset( dataset_name="HuggingFaceM4/FineVision", subset="llavar_gpt4_20k", # Use LLaVA subset as default split="train", streaming=True, max_samples=int(max_steps * batch_size * 0.3), generate_synthetic_reasoning=True, num_synthetic_steps=3, ) # Collate function to convert ReasoningChain to model inputs def collate_fn(batch): # Filter out None samples batch = [b for b in batch if b is not None] if not batch: # Return empty batch structure instead of None # This prevents Accelerate DataLoader from crashing return { 'input_ids': torch.tensor([[]]), 'attention_mask': torch.tensor([[]]), 'labels': torch.tensor([[]]), } # Extract text and images texts = [] images = [] for chain in batch: # Format: question + reasoning steps + answer text = f"Question: {chain.prompt}\n" text += "<|reasoning_start|>\n" for step in chain.steps: text += f"<|step_start|>{step.description}<|step_end|>\n" text += "<|reasoning_end|>\n" text += f"Answer: {chain.final_answer}" texts.append(text) # Handle missing images - append None which the processor should handle if hasattr(chain, 'image'): images.append(chain.image if chain.image is not None else None) else: images.append(None) # Tokenize with processor if hasattr(self.tokenizer, 'tokenizer'): actual_tokenizer = self.tokenizer.tokenizer else: actual_tokenizer = self.tokenizer # Use processor for vision-language models if hasattr(self.tokenizer, '__call__'): # Filter out None images - if all are None, pass None instead of list valid_images = [img for img in images if img is not None] images_param = valid_images if valid_images else None inputs = self.tokenizer( text=texts, images=images_param, return_tensors="pt", padding=True, truncation=True, max_length=2048, ) else: inputs = actual_tokenizer( texts, return_tensors="pt", padding=True, truncation=True, max_length=2048, ) # Set labels (same as input_ids for causal LM) inputs['labels'] = inputs['input_ids'].clone() return inputs # Interleave datasets with weighted sampling from itertools import cycle, islice def interleave_datasets(flux_ds, fv_ds, flux_weight=0.7): flux_iter = iter(flux_ds) fv_iter = iter(fv_ds) while True: # Sample from FLUX 70% of the time if torch.rand(1).item() < flux_weight: try: yield next(flux_iter) except StopIteration: flux_iter = iter(flux_ds) yield next(flux_iter) else: try: yield next(fv_iter) except StopIteration: fv_iter = iter(fv_ds) yield next(fv_iter) # Create interleaved dataset from torch.utils.data import IterableDataset class InterleavedDataset(IterableDataset): def __init__(self, ds1, ds2, weight1=0.7): self.ds1 = ds1 self.ds2 = ds2 self.weight1 = weight1 def __iter__(self): return interleave_datasets(self.ds1, self.ds2, self.weight1) mixed_dataset = InterleavedDataset(flux_dataset, finevision_dataset, 0.7) # Create dataloader train_dataloader = DataLoader( mixed_dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=0, ) # Setup optimizer optimizer = torch.optim.AdamW( self.model.parameters(), lr=float(learning_rate), weight_decay=getattr(self.config, 'weight_decay', 0.01), ) # Check for existing checkpoint to resume from resume_step = 0 checkpoint_path = self._find_latest_checkpoint("sft") if checkpoint_path: try: resume_step = self._load_checkpoint(checkpoint_path, "sft") logger.info(f"Resuming SFT training from step {resume_step}") except Exception as e: logger.warning(f"Failed to load checkpoint: {e}. Starting from scratch.") resume_step = 0 # Prepare with accelerator self.model, optimizer, train_dataloader = self.accelerator.prepare( self.model, optimizer, train_dataloader ) # Training loop self.model.train() global_step = resume_step total_loss = 0 from tqdm import tqdm progress_bar = tqdm( total=max_steps, initial=resume_step, desc="SFT Training", disable=not self.accelerator.is_local_main_process, ) for batch in train_dataloader: if batch is None: continue if global_step >= max_steps: break # Forward pass outputs = self.model( input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], labels=batch['labels'], ) loss = outputs.loss # Backward pass self.accelerator.backward(loss) optimizer.step() optimizer.zero_grad() total_loss += loss.item() global_step += 1 progress_bar.update(1) progress_bar.set_postfix({'loss': loss.item(), 'avg_loss': total_loss / global_step}) # Save checkpoint if global_step % save_steps == 0: logger.info(f"Saving checkpoint at step {global_step}") self._save_checkpoint(global_step, "sft") # Logging logging_steps = getattr(self.config, 'logging_steps', 100) if global_step % logging_steps == 0: logger.info(f"Step {global_step}/{max_steps} - Loss: {loss.item():.4f}") progress_bar.close() logger.info(f"SFT training completed - Avg loss: {total_loss / global_step:.4f}") def train_prm( self, max_steps: int = 5000, learning_rate: float = 1e-5, save_steps: int = 500, eval_steps: int = 500, **kwargs, # Accept additional config like hidden_dim, num_layers, dropout, reward_scale ) -> None: """ Train Process Reward Model. Args: max_steps: Maximum training steps learning_rate: Learning rate for PRM training save_steps: Save checkpoint every N steps eval_steps: Evaluate every N steps **kwargs: Additional PRM configuration (hidden_dim, num_layers, dropout, reward_scale) """ logger.info("Starting PRM training") logger.info(f"Max steps: {max_steps}, LR: {learning_rate}") if kwargs: logger.info(f"PRM config: {kwargs}") # Import dataset loaders from .flux_reason_dataset import FLUXReasonDataset from torch.utils.data import DataLoader # Load FLUX-Reason-6M for PRM training (has reasoning chains) logger.info("Loading FLUX-Reason-6M for PRM training...") batch_size = getattr(self.config, 'train_batch_size', 1) train_dataset = FLUXReasonDataset( dataset_name="LucasFang/FLUX-Reason-6M", split="train", streaming=True, max_samples=max_steps * batch_size, ) train_dataloader = DataLoader( train_dataset, batch_size=batch_size, shuffle=False, num_workers=0, ) # Setup PRM trainer from .prm import PRMTrainer prm_trainer = PRMTrainer( model=self.prm, learning_rate=float(learning_rate), weight_decay=getattr(self.config, 'weight_decay', 0.01), warmup_steps=getattr(self.config, 'warmup_steps', 500), ) # Prepare with accelerator prm_trainer.model, prm_trainer.optimizer, train_dataloader = self.accelerator.prepare( prm_trainer.model, prm_trainer.optimizer, train_dataloader ) # Training loop prm_trainer.model.train() global_step = 0 total_loss = 0 from tqdm import tqdm progress_bar = tqdm( total=max_steps, desc="PRM Training", disable=not self.accelerator.is_local_main_process, ) for batch in train_dataloader: if global_step >= max_steps: break # Extract features for PRM # PRM scores individual reasoning steps try: # Get vision features and step embeddings from batch vision_features = batch.get('image_features', batch.get('pixel_values')) step_texts = batch.get('reasoning_steps', []) # For now, use a simple reward signal based on correct/incorrect target_rewards = torch.ones(len(step_texts), dtype=torch.float32) # Train step metrics = prm_trainer.train_step( vision_features, step_texts, target_rewards, ) total_loss += metrics['loss'] global_step += 1 progress_bar.update(1) progress_bar.set_postfix({'loss': metrics['loss'], 'avg_loss': total_loss / global_step}) # Save checkpoint if global_step % save_steps == 0: logger.info(f"Saving PRM checkpoint at step {global_step}") self._save_checkpoint(global_step, "prm") except Exception as e: logger.warning(f"PRM training step failed: {e}, skipping batch") continue progress_bar.close() logger.info(f"PRM training completed - Avg loss: {total_loss / max(global_step, 1):.4f}") def train_rl( self, max_steps: int = 8000, learning_rate: float = 5e-6, save_steps: int = 500, eval_steps: int = 500, **kwargs, # Accept additional config like gamma, lam, cliprange, etc. ) -> None: """ Train policy with reinforcement learning. Args: max_steps: Maximum training steps learning_rate: Learning rate for RL training save_steps: Save checkpoint every N steps eval_steps: Evaluate every N steps **kwargs: Additional PPO configuration (gamma, lam, cliprange, vf_coef, ent_coef, etc.) """ logger.info("Starting RL training") logger.info(f"Max steps: {max_steps}, LR: {learning_rate}") if kwargs: logger.info(f"PPO config: {kwargs}") # Import dataset loaders from .flux_reason_dataset import FLUXReasonDataset from torch.utils.data import DataLoader # Get batch size with default batch_size = getattr(self.config, 'train_batch_size', 1) # Load dataset for RL training logger.info("Loading FLUX-Reason-6M for RL training...") train_dataset = FLUXReasonDataset( dataset_name="LucasFang/FLUX-Reason-6M", split="train", streaming=True, max_samples=max_steps * batch_size, ) train_dataloader = DataLoader( train_dataset, batch_size=batch_size, shuffle=False, num_workers=0, ) # Training loop with PPO global_step = 0 total_reward = 0 from tqdm import tqdm progress_bar = tqdm( total=max_steps, desc="RL Training", disable=not self.accelerator.is_local_main_process, ) for batch in train_dataloader: if global_step >= max_steps: break try: # Generate reasoning chains with current policy # Score with PRM # Update policy with PPO # For now, use the RL trainer self.rl_trainer.policy.train() # Simple training step (actual PPO would be more complex) vision_features = batch.get('image_features', batch.get('pixel_values')) # Generate and score chains # This is simplified - full PPO would involve: # 1. Generate multiple chains # 2. Score with PRM # 3. Compute advantages # 4. Update policy # Placeholder reward (would come from PRM) reward = 1.0 total_reward += reward global_step += 1 progress_bar.update(1) progress_bar.set_postfix({ 'reward': reward, 'avg_reward': total_reward / global_step }) # Save checkpoint if global_step % save_steps == 0: logger.info(f"Saving RL checkpoint at step {global_step}") self._save_checkpoint(global_step, "rl") except Exception as e: logger.warning(f"RL training step failed: {e}, skipping batch") continue progress_bar.close() logger.info(f"RL training completed - Avg reward: {total_reward / max(global_step, 1):.4f}") def evaluate_inference_scaling( self, test_dataset: StepDataset, num_samples: int = 10, ) -> Dict[str, Any]: """ Evaluate inference-time scaling performance. Args: test_dataset: Test dataset num_samples: Number of test samples Returns: Evaluation metrics """ logger.info("Evaluating inference-time scaling") results = [] for idx in range(min(num_samples, len(test_dataset))): sample = test_dataset[idx] # Run inference with scaling best_chain, all_chains = self.inference_scaler( image_features=sample['visual_features'][0], prompt_ids=sample['prompt_ids'], prompt_mask=sample['prompt_mask'], image_path=sample['image_path'], prompt_text="", # Would decode from IDs ) results.append({ 'chain_id': sample['chain_id'], 'best_reward': best_chain.total_reward, 'num_steps': len(best_chain.steps), 'avg_reward_all': sum(c.total_reward for c in all_chains) / len(all_chains), 'std_reward_all': torch.tensor([c.total_reward for c in all_chains]).std().item(), }) # Aggregate metrics metrics = { 'avg_best_reward': sum(r['best_reward'] for r in results) / len(results), 'avg_num_steps': sum(r['num_steps'] for r in results) / len(results), 'improvement_over_mean': sum( (r['best_reward'] - r['avg_reward_all']) / (r['std_reward_all'] + 1e-8) for r in results ) / len(results), } logger.info(f"Inference scaling metrics: {metrics}") return metrics def _evaluate_prm( self, prm_trainer: PRMTrainer, val_dataset: StepDataset, ) -> Dict[str, float]: """Evaluate PRM on validation set.""" val_dataloader = torch.utils.data.DataLoader( val_dataset, batch_size=self.config.train_batch_size, shuffle=False, ) total_metrics = {'mse': 0, 'mae': 0, 'correlation': 0} for batch in val_dataloader: vision_features = batch['visual_features'] step_embeddings = batch['step_descriptions'] target_rewards = batch['step_rewards'] metrics = prm_trainer.evaluate( vision_features, step_embeddings, target_rewards, ) for key in total_metrics: total_metrics[key] += metrics[key] # Average for key in total_metrics: total_metrics[key] /= len(val_dataloader) return total_metrics def _save_checkpoint(self, step: int, phase: str) -> None: """Save checkpoint for a specific training phase and upload to Hub.""" save_path = Path(self.config.output_dir) / f"checkpoint-{phase}-{step}" save_path.mkdir(parents=True, exist_ok=True) if phase == "sft": # Save the main model self.accelerator.unwrap_model(self.model).save_pretrained(save_path) if hasattr(self, 'tokenizer'): self.tokenizer.save_pretrained(save_path) # Upload to HuggingFace Hub for persistence try: from huggingface_hub import HfApi api = HfApi() repo_id = getattr(self.config, 'hub_model_id', 'Mulebot/qwen-cot-checkpoints') logger.info(f"Uploading checkpoint to {repo_id}...") api.upload_folder( folder_path=str(save_path), repo_id=repo_id, path_in_repo=f"checkpoint-{phase}-{step}", repo_type="model", commit_message=f"Save {phase} checkpoint at step {step}", ) logger.info(f"Checkpoint uploaded to hub: {repo_id}/checkpoint-{phase}-{step}") except Exception as e: logger.warning(f"Failed to upload checkpoint to hub: {e}") elif phase == "prm": # Save PRM model torch.save( self.accelerator.unwrap_model(self.prm).state_dict(), save_path / "prm_model.pt", ) elif phase == "rl": # Save RL policy torch.save( self.accelerator.unwrap_model(self.rl_trainer.policy).state_dict(), save_path / "rl_policy.pt", ) logger.info(f"Checkpoint saved to {save_path}") def _find_latest_checkpoint(self, phase: str) -> Optional[Path]: """Find the latest checkpoint for a specific phase from Hub or local.""" # First, try to find checkpoints on HuggingFace Hub try: from huggingface_hub import HfApi, hf_hub_download api = HfApi() repo_id = getattr(self.config, 'hub_model_id', 'Mulebot/qwen-cot-checkpoints') logger.info(f"Checking for checkpoints on {repo_id}...") # List all files in the repo try: files = api.list_repo_files(repo_id, repo_type="model") # Find checkpoint directories checkpoint_dirs = set() for file in files: if f"checkpoint-{phase}-" in file: # Extract the checkpoint directory name parts = file.split("/") if len(parts) > 0: checkpoint_dirs.add(parts[0]) if checkpoint_dirs: # Extract step numbers checkpoint_steps = [] for ckpt_dir in checkpoint_dirs: try: step = int(ckpt_dir.split("-")[-1]) checkpoint_steps.append((step, ckpt_dir)) except ValueError: continue if checkpoint_steps: # Get the latest checkpoint latest_step, latest_dir = max(checkpoint_steps, key=lambda x: x[0]) logger.info(f"Found checkpoint on hub at step {latest_step}: {latest_dir}") # Download the checkpoint local_dir = Path(self.config.output_dir) / latest_dir local_dir.mkdir(parents=True, exist_ok=True) logger.info(f"Downloading checkpoint from hub...") from huggingface_hub import snapshot_download snapshot_download( repo_id=repo_id, repo_type="model", allow_patterns=f"{latest_dir}/*", local_dir=self.config.output_dir, local_dir_use_symlinks=False, ) logger.info(f"Checkpoint downloaded to {local_dir}") return local_dir except Exception as e: logger.info(f"No checkpoints found on hub or error accessing: {e}") except ImportError: logger.warning("huggingface_hub not available, skipping hub checkpoint check") # Fall back to local checkpoints checkpoint_dir = Path(self.config.output_dir) if not checkpoint_dir.exists(): return None # Find all checkpoints for this phase checkpoints = list(checkpoint_dir.glob(f"checkpoint-{phase}-*")) if not checkpoints: return None # Extract step numbers and find the latest checkpoint_steps = [] for ckpt in checkpoints: try: step = int(ckpt.name.split("-")[-1]) checkpoint_steps.append((step, ckpt)) except ValueError: continue if not checkpoint_steps: return None # Return checkpoint with highest step number latest_step, latest_path = max(checkpoint_steps, key=lambda x: x[0]) logger.info(f"Found local checkpoint at step {latest_step}: {latest_path}") return latest_path def _load_checkpoint(self, checkpoint_path: Path, phase: str) -> int: """Load checkpoint and return the step number.""" if phase == "sft": # Load model from checkpoint from transformers import AutoModelForVision2Seq logger.info(f"Loading SFT checkpoint from {checkpoint_path}") # Load the checkpoint model checkpoint_model = AutoModelForVision2Seq.from_pretrained( checkpoint_path, torch_dtype=torch.bfloat16, device_map=self.accelerator.device, ) # Copy weights to current model self.model.load_state_dict(checkpoint_model.state_dict()) del checkpoint_model elif phase == "prm": logger.info(f"Loading PRM checkpoint from {checkpoint_path}") self.prm.load_state_dict( torch.load(checkpoint_path / "prm_model.pt", map_location=self.accelerator.device) ) elif phase == "rl": logger.info(f"Loading RL checkpoint from {checkpoint_path}") self.rl_trainer.policy.load_state_dict( torch.load(checkpoint_path / "rl_policy.pt", map_location=self.accelerator.device) ) # Extract step number from checkpoint path step = int(checkpoint_path.name.split("-")[-1]) logger.info(f"Resumed from checkpoint at step {step}") return step def _save_prm(self, epoch: int) -> None: """Save PRM checkpoint.""" save_path = Path(self.config.output_dir) / "prm" / f"epoch_{epoch}" save_path.mkdir(parents=True, exist_ok=True) torch.save( self.prm.state_dict(), save_path / "prm_model.pt", ) logger.info(f"PRM saved to {save_path}") def save_checkpoint(self, step: int) -> None: """Save complete training checkpoint.""" save_path = Path(self.config.output_dir) / f"checkpoint-{step}" save_path.mkdir(parents=True, exist_ok=True) # Save policy self.accelerator.save_state(str(save_path / "policy")) # Save PRM torch.save(self.prm.state_dict(), save_path / "prm.pt") # Save RL trainer state self.rl_trainer.save_checkpoint(str(save_path / "rl_trainer.pt")) logger.info(f"Checkpoint saved to {save_path}") def main(): """Main training entry point.""" parser = argparse.ArgumentParser(description="Train step-level CoT reasoning for vision models") parser.add_argument("--config", type=str, required=True, help="Path to config YAML") parser.add_argument("--phase", type=str, choices=["prm", "rl", "inference", "all"], default="all") parser.add_argument("--data_dir", type=str, required=True, help="Path to reasoning chain data") args = parser.parse_args() # Initialize trainer trainer = StepLevelCoTTrainer( config_path=args.config, device="cuda" if torch.cuda.is_available() else "cpu", ) # Load datasets train_dataset = trainer.load_step_dataset(args.data_dir, split="train") val_dataset = trainer.load_step_dataset(args.data_dir, split="val") test_dataset = trainer.load_step_dataset(args.data_dir, split="test") # Run training phases if args.phase in ["prm", "all"]: trainer.train_prm(train_dataset, val_dataset) if args.phase in ["rl", "all"]: trainer.train_rl(train_dataset) if args.phase in ["inference", "all"]: trainer.evaluate_inference_scaling(test_dataset) if __name__ == "__main__": main()