"""
Step-Level Chain of Thought Trainer for Vision-Language Models.
Integrates PRM, RL training, and inference-time scaling for fine-grained reasoning.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, Optional, Any

import torch
import yaml
from accelerate import Accelerator
from accelerate.logging import get_logger
from transformers import AutoTokenizer
from diffusers import QwenImagePipeline, QwenImageTransformer2DModel
from omegaconf import OmegaConf
from tqdm.auto import tqdm

from .step_data import StepDataset, ReasoningChain, SPECIAL_TOKENS
from .prm import ProcessRewardModel, PRMTrainer
from .rl_trainer import RLReasoningTrainer, RLConfig
from .inference_scaling import InferenceTimeScaling, InferenceConfig

logger = get_logger(__name__, log_level="INFO")
class StepLevelCoTTrainer:
    """
    Complete training pipeline for step-level chain-of-thought reasoning.

    Training phases:
        1. Supervised Fine-tuning (SFT) on reasoning chains
        2. PRM training on step-level rewards
        3. RL fine-tuning with PRM rewards
        4. Inference-time scaling evaluation
    """
    def __init__(
        self,
        config_path: str,
        device: str = "cuda",
    ):
        """
        Initialize trainer.

        Loads the YAML config, builds the Accelerate context, loads the base
        model/tokenizer, and constructs the PRM, RL trainer, and inference scaler.

        Args:
            config_path: Path to training configuration YAML
            device: Device for training
        """
        self.device = device
        # Load configuration (YAML -> OmegaConf for attribute-style access)
        with open(config_path, 'r') as f:
            config_dict = yaml.safe_load(f)
        self.config = OmegaConf.create(config_dict)
        # Initialize accelerator
        self.accelerator = Accelerator(
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            mixed_precision=self.config.mixed_precision,
            log_with=self.config.report_to,
            project_dir=self.config.output_dir,
        )
        # Load model and tokenizer (sets self.model / self.tokenizer, plus
        # pipeline/vae/text_encoder for diffusion models)
        self._load_model()
        # Initialize PRM
        self.prm = ProcessRewardModel(
            vision_dim=self.config.vision_dim,
            text_dim=self.config.text_dim,
            hidden_dim=self.config.prm_hidden_dim,
            num_heads=self.config.prm_num_heads,
            dropout=self.config.prm_dropout,
            max_steps=self.config.max_reasoning_steps,
        )
        # Initialize RL trainer with explicit type conversion
        # (YAML/OmegaConf values may arrive as strings, so coerce explicitly)
        rl_config = RLConfig(
            clip_epsilon=float(self.config.ppo_clip_epsilon),
            value_loss_coef=float(self.config.value_loss_coef),
            entropy_coef=float(self.config.entropy_coef),
            gamma=float(self.config.gamma),
            gae_lambda=float(self.config.gae_lambda),
            learning_rate=float(self.config.rl_learning_rate),
            batch_size=int(self.config.train_batch_size),
            use_prm_rewards=True,
            normalize_rewards=True,
        )
        self.rl_trainer = RLReasoningTrainer(
            policy_model=self.model,
            prm_model=self.prm,
            config=rl_config,
            device=device,
        )
        # Initialize inference scaler
        # Ensure numeric types are properly converted
        try:
            num_samples = int(self.config.num_inference_samples)
            temperature = float(self.config.inference_temperature)
            aggregation = str(self.config.aggregation_method)
            inference_config = InferenceConfig(
                num_samples=num_samples,
                temperature=temperature,
                aggregation=aggregation,
                use_prm_scores=True,
            )
            self.inference_scaler = InferenceTimeScaling(
                model=self.model,
                prm=self.prm,
                config=inference_config,
                device=device,
            )
        except Exception as e:
            # Log the offending config values (with their types) before
            # re-raising so misconfigured YAML is easy to diagnose.
            logger.error(f"Failed to initialize inference scaler: {e}")
            logger.error(f"Config values - num_samples: {self.config.num_inference_samples} (type: {type(self.config.num_inference_samples)})")
            logger.error(f"Config values - temperature: {self.config.inference_temperature} (type: {type(self.config.inference_temperature)})")
            logger.error(f"Config values - aggregation: {self.config.aggregation_method} (type: {type(self.config.aggregation_method)})")
            raise
        logger.info("StepLevelCoTTrainer initialized")
    def _load_model(self) -> None:
        """Load vision-language model (supports both VLM and diffusion models).

        Always sets self.model and self.tokenizer; self.pipeline / self.vae /
        self.text_encoder are populated only for diffusion models (None for VLMs).
        """
        logger.info(f"Loading model: {self.config.pretrained_model_name_or_path}")
        model_name = self.config.pretrained_model_name_or_path.lower()
        # Detect model type by substring matching on the model id.
        # NOTE(review): "sd" matches any id containing that substring; VLM
        # detection wins because it is checked first — confirm intended.
        is_vlm = any(x in model_name for x in ["qwen2", "qwen-vl", "internvl", "llava", "cogvlm"])
        is_diffusion = any(x in model_name for x in ["flux", "qwen-image", "stable-diffusion", "sd"])
        if is_vlm:
            # Load as Vision-Language Model
            from transformers import AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig
            logger.info("Loading as Vision-Language Model...")
            # Prepare quantization config if enabled
            quantization_config = None
            if self.config.get("quantize", False):
                quantization_config = BitsAndBytesConfig(
                    load_in_8bit=True,
                    llm_int8_threshold=6.0,
                    llm_int8_has_fp16_weight=False,
                )
            # Load processor and model
            self.tokenizer = AutoProcessor.from_pretrained(
                self.config.pretrained_model_name_or_path,
                trust_remote_code=True
            )
            torch_dtype = torch.float16 if self.config.mixed_precision == "fp16" else torch.bfloat16
            self.model = AutoModelForVision2Seq.from_pretrained(
                self.config.pretrained_model_name_or_path,
                torch_dtype=torch_dtype,
                quantization_config=quantization_config,
                trust_remote_code=True,
                device_map="auto"
            )
            # For VLMs, model itself is the main component
            self.pipeline = None
            self.vae = None
            self.text_encoder = None
        elif is_diffusion:
            # Load as Diffusion Model
            logger.info("Loading as Diffusion Model...")
            self.pipeline = QwenImagePipeline.from_pretrained(
                self.config.pretrained_model_name_or_path,
                torch_dtype=torch.float16 if self.config.mixed_precision == "fp16" else torch.float32,
            )
            # Expose the pipeline components the trainer uses directly
            self.model = self.pipeline.transformer
            self.vae = self.pipeline.vae
            self.text_encoder = self.pipeline.text_encoder
            self.tokenizer = self.pipeline.tokenizer
        else:
            # Try loading as VLM by default
            logger.warning(f"Unknown model type, attempting to load as VLM...")
            from transformers import AutoProcessor, AutoModelForVision2Seq
            self.tokenizer = AutoProcessor.from_pretrained(
                self.config.pretrained_model_name_or_path,
                trust_remote_code=True
            )
            self.model = AutoModelForVision2Seq.from_pretrained(
                self.config.pretrained_model_name_or_path,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device_map="auto"
            )
            self.pipeline = None
            self.vae = None
            self.text_encoder = None
        # Add special tokens for chain-of-thought reasoning
        self._add_special_tokens()
        # Enable gradient checkpointing if available
        # (different model classes expose different method names for this)
        if self.config.gradient_checkpointing:
            if hasattr(self.model, 'enable_gradient_checkpointing'):
                self.model.enable_gradient_checkpointing()
            elif hasattr(self.model, 'gradient_checkpointing_enable'):
                self.model.gradient_checkpointing_enable()
            else:
                logger.warning("Gradient checkpointing requested but not available for this model")
        logger.info("Model loaded successfully")
| def _add_special_tokens(self) -> None: | |
| """Add special tokens for chain-of-thought reasoning.""" | |
| from .step_data import SPECIAL_TOKENS, SPECIAL_TOKEN_DESCRIPTIONS | |
| logger.info(f"Adding {len(SPECIAL_TOKENS)} special tokens for chain-of-thought reasoning...") | |
| # Get actual tokenizer (handle both Tokenizer and Processor objects) | |
| if hasattr(self.tokenizer, 'tokenizer'): | |
| # This is a processor, extract the tokenizer | |
| actual_tokenizer = self.tokenizer.tokenizer | |
| else: | |
| # This is already a tokenizer | |
| actual_tokenizer = self.tokenizer | |
| # Add special tokens to tokenizer | |
| num_added_toks = actual_tokenizer.add_special_tokens( | |
| {"additional_special_tokens": SPECIAL_TOKENS} | |
| ) | |
| if num_added_toks > 0: | |
| logger.info(f"Added {num_added_toks} new tokens. Resizing model embeddings...") | |
| # Get current embedding size | |
| old_size = self.model.get_input_embeddings().weight.size(0) | |
| # Resize token embeddings (use actual tokenizer's vocab size) | |
| new_size = len(actual_tokenizer) | |
| self.model.resize_token_embeddings(new_size) | |
| # Initialize new token embeddings with semantic initialization | |
| self._initialize_new_token_embeddings(num_added_toks, SPECIAL_TOKEN_DESCRIPTIONS) | |
| logger.info(f"Resized embeddings from {old_size} to {new_size}") | |
| else: | |
| logger.info("Special tokens already present in tokenizer.") | |
| def _initialize_new_token_embeddings(self, num_added_toks: int, descriptions: list) -> None: | |
| """Initialize newly added token embeddings with semantic information.""" | |
| # Get actual tokenizer (handle both Tokenizer and Processor objects) | |
| if hasattr(self.tokenizer, 'tokenizer'): | |
| actual_tokenizer = self.tokenizer.tokenizer | |
| else: | |
| actual_tokenizer = self.tokenizer | |
| with torch.no_grad(): | |
| input_embeddings = self.model.get_input_embeddings() | |
| # Check if model has language model head | |
| has_lm_head = hasattr(self.model, 'lm_head') | |
| if has_lm_head: | |
| lm_head = self.model.lm_head | |
| for i, description in enumerate(reversed(descriptions), start=1): | |
| # Tokenize the description using actual tokenizer | |
| tokenized = actual_tokenizer.tokenize(description) | |
| if not tokenized: | |
| continue | |
| tokenized_ids = actual_tokenizer.convert_tokens_to_ids(tokenized) | |
| # Average embeddings of description tokens | |
| new_embedding = input_embeddings.weight[tokenized_ids].mean(dim=0) | |
| input_embeddings.weight[-i, :] = new_embedding.clone().detach() | |
| # Initialize lm_head if it exists | |
| if has_lm_head: | |
| new_lm_head = lm_head.weight[tokenized_ids].mean(dim=0) | |
| lm_head.weight[-i, :] = new_lm_head.clone().detach() | |
| logger.info(f"Initialized {num_added_toks} new token embeddings with semantic information.") | |
| def load_step_dataset( | |
| self, | |
| data_dir: str, | |
| split: str = "train", | |
| ) -> StepDataset: | |
| """ | |
| Load step-level reasoning dataset. | |
| Args: | |
| data_dir: Directory containing reasoning chain JSON files | |
| split: Dataset split (train/val/test) | |
| Returns: | |
| StepDataset instance | |
| """ | |
| data_path = Path(data_dir) / split | |
| logger.info(f"Loading dataset from {data_path}") | |
| dataset = StepDataset.from_json_dir( | |
| str(data_path), | |
| tokenizer=self.tokenizer, | |
| max_steps=self.config.max_reasoning_steps, | |
| include_visual_features=True, | |
| filter_incorrect=False, # Include all chains for learning | |
| ) | |
| stats = dataset.get_statistics() | |
| logger.info(f"Dataset loaded: {stats}") | |
| return dataset | |
| def train_sft( | |
| self, | |
| max_steps: int = 10000, | |
| learning_rate: float = 2e-5, | |
| save_steps: int = 1000, | |
| eval_steps: int = 500, | |
| ) -> None: | |
| """ | |
| Supervised Fine-Tuning phase for the base model. | |
| Args: | |
| max_steps: Maximum training steps | |
| learning_rate: Learning rate for SFT | |
| save_steps: Save checkpoint every N steps | |
| eval_steps: Evaluate every N steps | |
| """ | |
| logger.info("Starting Supervised Fine-Tuning (SFT)") | |
| logger.info(f"Max steps: {max_steps}, LR: {learning_rate}") | |
| # Import dataset loaders | |
| from .flux_reason_dataset import FLUXReasonDataset | |
| from .finevision_dataset import FineVisionDataset | |
| from torch.utils.data import DataLoader | |
| # Get batch size with default | |
| batch_size = getattr(self.config, 'train_batch_size', 1) | |
| # Load FLUX-Reason-6M dataset (70% weight) | |
| logger.info("Loading FLUX-Reason-6M dataset...") | |
| flux_dataset = FLUXReasonDataset( | |
| dataset_name="LucasFang/FLUX-Reason-6M", | |
| split="train", | |
| streaming=True, | |
| max_samples=int(max_steps * batch_size * 0.7), | |
| ) | |
| # Load FineVision dataset (30% weight) | |
| logger.info("Loading FineVision dataset...") | |
| finevision_dataset = FineVisionDataset( | |
| dataset_name="HuggingFaceM4/FineVision", | |
| subset="llavar_gpt4_20k", # Use LLaVA subset as default | |
| split="train", | |
| streaming=True, | |
| max_samples=int(max_steps * batch_size * 0.3), | |
| generate_synthetic_reasoning=True, | |
| num_synthetic_steps=3, | |
| ) | |
| # Collate function to convert ReasoningChain to model inputs | |
| def collate_fn(batch): | |
| # Filter out None samples | |
| batch = [b for b in batch if b is not None] | |
| if not batch: | |
| # Return empty batch structure instead of None | |
| # This prevents Accelerate DataLoader from crashing | |
| return { | |
| 'input_ids': torch.tensor([[]]), | |
| 'attention_mask': torch.tensor([[]]), | |
| 'labels': torch.tensor([[]]), | |
| } | |
| # Extract text and images | |
| texts = [] | |
| images = [] | |
| for chain in batch: | |
| # Format: question + reasoning steps + answer | |
| text = f"Question: {chain.prompt}\n" | |
| text += "<|reasoning_start|>\n" | |
| for step in chain.steps: | |
| text += f"<|step_start|>{step.description}<|step_end|>\n" | |
| text += "<|reasoning_end|>\n" | |
| text += f"Answer: {chain.final_answer}" | |
| texts.append(text) | |
| # Handle missing images - append None which the processor should handle | |
| if hasattr(chain, 'image'): | |
| images.append(chain.image if chain.image is not None else None) | |
| else: | |
| images.append(None) | |
| # Tokenize with processor | |
| if hasattr(self.tokenizer, 'tokenizer'): | |
| actual_tokenizer = self.tokenizer.tokenizer | |
| else: | |
| actual_tokenizer = self.tokenizer | |
| # Use processor for vision-language models | |
| if hasattr(self.tokenizer, '__call__'): | |
| # Filter out None images - if all are None, pass None instead of list | |
| valid_images = [img for img in images if img is not None] | |
| images_param = valid_images if valid_images else None | |
| inputs = self.tokenizer( | |
| text=texts, | |
| images=images_param, | |
| return_tensors="pt", | |
| padding=True, | |
| truncation=True, | |
| max_length=2048, | |
| ) | |
| else: | |
| inputs = actual_tokenizer( | |
| texts, | |
| return_tensors="pt", | |
| padding=True, | |
| truncation=True, | |
| max_length=2048, | |
| ) | |
| # Set labels (same as input_ids for causal LM) | |
| inputs['labels'] = inputs['input_ids'].clone() | |
| return inputs | |
| # Interleave datasets with weighted sampling | |
| from itertools import cycle, islice | |
| def interleave_datasets(flux_ds, fv_ds, flux_weight=0.7): | |
| flux_iter = iter(flux_ds) | |
| fv_iter = iter(fv_ds) | |
| while True: | |
| # Sample from FLUX 70% of the time | |
| if torch.rand(1).item() < flux_weight: | |
| try: | |
| yield next(flux_iter) | |
| except StopIteration: | |
| flux_iter = iter(flux_ds) | |
| yield next(flux_iter) | |
| else: | |
| try: | |
| yield next(fv_iter) | |
| except StopIteration: | |
| fv_iter = iter(fv_ds) | |
| yield next(fv_iter) | |
| # Create interleaved dataset | |
| from torch.utils.data import IterableDataset | |
| class InterleavedDataset(IterableDataset): | |
| def __init__(self, ds1, ds2, weight1=0.7): | |
| self.ds1 = ds1 | |
| self.ds2 = ds2 | |
| self.weight1 = weight1 | |
| def __iter__(self): | |
| return interleave_datasets(self.ds1, self.ds2, self.weight1) | |
| mixed_dataset = InterleavedDataset(flux_dataset, finevision_dataset, 0.7) | |
| # Create dataloader | |
| train_dataloader = DataLoader( | |
| mixed_dataset, | |
| batch_size=batch_size, | |
| collate_fn=collate_fn, | |
| num_workers=0, | |
| ) | |
| # Setup optimizer | |
| optimizer = torch.optim.AdamW( | |
| self.model.parameters(), | |
| lr=float(learning_rate), | |
| weight_decay=getattr(self.config, 'weight_decay', 0.01), | |
| ) | |
| # Check for existing checkpoint to resume from | |
| resume_step = 0 | |
| checkpoint_path = self._find_latest_checkpoint("sft") | |
| if checkpoint_path: | |
| try: | |
| resume_step = self._load_checkpoint(checkpoint_path, "sft") | |
| logger.info(f"Resuming SFT training from step {resume_step}") | |
| except Exception as e: | |
| logger.warning(f"Failed to load checkpoint: {e}. Starting from scratch.") | |
| resume_step = 0 | |
| # Prepare with accelerator | |
| self.model, optimizer, train_dataloader = self.accelerator.prepare( | |
| self.model, optimizer, train_dataloader | |
| ) | |
| # Training loop | |
| self.model.train() | |
| global_step = resume_step | |
| total_loss = 0 | |
| from tqdm import tqdm | |
| progress_bar = tqdm( | |
| total=max_steps, | |
| initial=resume_step, | |
| desc="SFT Training", | |
| disable=not self.accelerator.is_local_main_process, | |
| ) | |
| for batch in train_dataloader: | |
| if batch is None: | |
| continue | |
| if global_step >= max_steps: | |
| break | |
| # Forward pass | |
| outputs = self.model( | |
| input_ids=batch['input_ids'], | |
| attention_mask=batch['attention_mask'], | |
| labels=batch['labels'], | |
| ) | |
| loss = outputs.loss | |
| # Backward pass | |
| self.accelerator.backward(loss) | |
| optimizer.step() | |
| optimizer.zero_grad() | |
| total_loss += loss.item() | |
| global_step += 1 | |
| progress_bar.update(1) | |
| progress_bar.set_postfix({'loss': loss.item(), 'avg_loss': total_loss / global_step}) | |
| # Save checkpoint | |
| if global_step % save_steps == 0: | |
| logger.info(f"Saving checkpoint at step {global_step}") | |
| self._save_checkpoint(global_step, "sft") | |
| # Logging | |
| logging_steps = getattr(self.config, 'logging_steps', 100) | |
| if global_step % logging_steps == 0: | |
| logger.info(f"Step {global_step}/{max_steps} - Loss: {loss.item():.4f}") | |
| progress_bar.close() | |
| logger.info(f"SFT training completed - Avg loss: {total_loss / global_step:.4f}") | |
    def train_prm(
        self,
        max_steps: int = 5000,
        learning_rate: float = 1e-5,
        save_steps: int = 500,
        eval_steps: int = 500,
        **kwargs,  # Accept additional config like hidden_dim, num_layers, dropout, reward_scale
    ) -> None:
        """
        Train Process Reward Model.

        NOTE(review): the target rewards below are a constant 1.0 placeholder —
        real step-level correctness labels still need to be wired in before this
        produces a meaningful PRM.

        Args:
            max_steps: Maximum training steps
            learning_rate: Learning rate for PRM training
            save_steps: Save checkpoint every N steps
            eval_steps: Evaluate every N steps (currently unused)
            **kwargs: Additional PRM configuration (hidden_dim, num_layers, dropout, reward_scale)
        """
        logger.info("Starting PRM training")
        logger.info(f"Max steps: {max_steps}, LR: {learning_rate}")
        if kwargs:
            logger.info(f"PRM config: {kwargs}")
        # Import dataset loaders
        from .flux_reason_dataset import FLUXReasonDataset
        from torch.utils.data import DataLoader
        # Load FLUX-Reason-6M for PRM training (has reasoning chains)
        logger.info("Loading FLUX-Reason-6M for PRM training...")
        batch_size = getattr(self.config, 'train_batch_size', 1)
        train_dataset = FLUXReasonDataset(
            dataset_name="LucasFang/FLUX-Reason-6M",
            split="train",
            streaming=True,
            max_samples=max_steps * batch_size,
        )
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=False,  # streaming dataset: order is fixed by the stream
            num_workers=0,
        )
        # Setup PRM trainer
        from .prm import PRMTrainer
        prm_trainer = PRMTrainer(
            model=self.prm,
            learning_rate=float(learning_rate),
            weight_decay=getattr(self.config, 'weight_decay', 0.01),
            warmup_steps=getattr(self.config, 'warmup_steps', 500),
        )
        # Prepare with accelerator
        prm_trainer.model, prm_trainer.optimizer, train_dataloader = self.accelerator.prepare(
            prm_trainer.model, prm_trainer.optimizer, train_dataloader
        )
        # Training loop
        prm_trainer.model.train()
        global_step = 0
        total_loss = 0
        from tqdm import tqdm
        progress_bar = tqdm(
            total=max_steps,
            desc="PRM Training",
            disable=not self.accelerator.is_local_main_process,
        )
        for batch in train_dataloader:
            if global_step >= max_steps:
                break
            # Extract features for PRM
            # PRM scores individual reasoning steps
            try:
                # Get vision features and step embeddings from batch
                # (presumably one of 'image_features'/'pixel_values' is present —
                # verify against the dataset collate)
                vision_features = batch.get('image_features', batch.get('pixel_values'))
                step_texts = batch.get('reasoning_steps', [])
                # For now, use a simple reward signal based on correct/incorrect
                # (placeholder: every step is labeled 1.0)
                target_rewards = torch.ones(len(step_texts), dtype=torch.float32)
                # Train step
                metrics = prm_trainer.train_step(
                    vision_features,
                    step_texts,
                    target_rewards,
                )
                total_loss += metrics['loss']
                global_step += 1
                progress_bar.update(1)
                progress_bar.set_postfix({'loss': metrics['loss'], 'avg_loss': total_loss / global_step})
                # Save checkpoint
                if global_step % save_steps == 0:
                    logger.info(f"Saving PRM checkpoint at step {global_step}")
                    self._save_checkpoint(global_step, "prm")
            except Exception as e:
                # Best-effort training: a malformed batch is logged and skipped
                # rather than aborting the whole run.
                logger.warning(f"PRM training step failed: {e}, skipping batch")
                continue
        progress_bar.close()
        # max(global_step, 1) guards division by zero when no step succeeded
        logger.info(f"PRM training completed - Avg loss: {total_loss / max(global_step, 1):.4f}")
    def train_rl(
        self,
        max_steps: int = 8000,
        learning_rate: float = 5e-6,
        save_steps: int = 500,
        eval_steps: int = 500,
        **kwargs,  # Accept additional config like gamma, lam, cliprange, etc.
    ) -> None:
        """
        Train policy with reinforcement learning.

        NOTE(review): the PPO loop body is still a placeholder — no chains are
        generated or scored, and the reward is a constant 1.0 (see below).

        Args:
            max_steps: Maximum training steps
            learning_rate: Learning rate for RL training
            save_steps: Save checkpoint every N steps
            eval_steps: Evaluate every N steps (currently unused)
            **kwargs: Additional PPO configuration (gamma, lam, cliprange, vf_coef, ent_coef, etc.)
        """
        logger.info("Starting RL training")
        logger.info(f"Max steps: {max_steps}, LR: {learning_rate}")
        if kwargs:
            logger.info(f"PPO config: {kwargs}")
        # Import dataset loaders
        from .flux_reason_dataset import FLUXReasonDataset
        from torch.utils.data import DataLoader
        # Get batch size with default
        batch_size = getattr(self.config, 'train_batch_size', 1)
        # Load dataset for RL training
        logger.info("Loading FLUX-Reason-6M for RL training...")
        train_dataset = FLUXReasonDataset(
            dataset_name="LucasFang/FLUX-Reason-6M",
            split="train",
            streaming=True,
            max_samples=max_steps * batch_size,
        )
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
        )
        # Training loop with PPO
        global_step = 0
        total_reward = 0
        from tqdm import tqdm
        progress_bar = tqdm(
            total=max_steps,
            desc="RL Training",
            disable=not self.accelerator.is_local_main_process,
        )
        for batch in train_dataloader:
            if global_step >= max_steps:
                break
            try:
                # Generate reasoning chains with current policy
                # Score with PRM
                # Update policy with PPO
                # For now, use the RL trainer
                self.rl_trainer.policy.train()
                # Simple training step (actual PPO would be more complex)
                vision_features = batch.get('image_features', batch.get('pixel_values'))
                # Generate and score chains
                # This is simplified - full PPO would involve:
                # 1. Generate multiple chains
                # 2. Score with PRM
                # 3. Compute advantages
                # 4. Update policy
                # Placeholder reward (would come from PRM)
                reward = 1.0
                total_reward += reward
                global_step += 1
                progress_bar.update(1)
                progress_bar.set_postfix({
                    'reward': reward,
                    'avg_reward': total_reward / global_step
                })
                # Save checkpoint
                if global_step % save_steps == 0:
                    logger.info(f"Saving RL checkpoint at step {global_step}")
                    self._save_checkpoint(global_step, "rl")
            except Exception as e:
                # Best-effort: log and skip failed batches rather than aborting
                logger.warning(f"RL training step failed: {e}, skipping batch")
                continue
        progress_bar.close()
        # max(global_step, 1) guards division by zero when no step ran
        logger.info(f"RL training completed - Avg reward: {total_reward / max(global_step, 1):.4f}")
| def evaluate_inference_scaling( | |
| self, | |
| test_dataset: StepDataset, | |
| num_samples: int = 10, | |
| ) -> Dict[str, Any]: | |
| """ | |
| Evaluate inference-time scaling performance. | |
| Args: | |
| test_dataset: Test dataset | |
| num_samples: Number of test samples | |
| Returns: | |
| Evaluation metrics | |
| """ | |
| logger.info("Evaluating inference-time scaling") | |
| results = [] | |
| for idx in range(min(num_samples, len(test_dataset))): | |
| sample = test_dataset[idx] | |
| # Run inference with scaling | |
| best_chain, all_chains = self.inference_scaler( | |
| image_features=sample['visual_features'][0], | |
| prompt_ids=sample['prompt_ids'], | |
| prompt_mask=sample['prompt_mask'], | |
| image_path=sample['image_path'], | |
| prompt_text="", # Would decode from IDs | |
| ) | |
| results.append({ | |
| 'chain_id': sample['chain_id'], | |
| 'best_reward': best_chain.total_reward, | |
| 'num_steps': len(best_chain.steps), | |
| 'avg_reward_all': sum(c.total_reward for c in all_chains) / len(all_chains), | |
| 'std_reward_all': torch.tensor([c.total_reward for c in all_chains]).std().item(), | |
| }) | |
| # Aggregate metrics | |
| metrics = { | |
| 'avg_best_reward': sum(r['best_reward'] for r in results) / len(results), | |
| 'avg_num_steps': sum(r['num_steps'] for r in results) / len(results), | |
| 'improvement_over_mean': sum( | |
| (r['best_reward'] - r['avg_reward_all']) / (r['std_reward_all'] + 1e-8) | |
| for r in results | |
| ) / len(results), | |
| } | |
| logger.info(f"Inference scaling metrics: {metrics}") | |
| return metrics | |
| def _evaluate_prm( | |
| self, | |
| prm_trainer: PRMTrainer, | |
| val_dataset: StepDataset, | |
| ) -> Dict[str, float]: | |
| """Evaluate PRM on validation set.""" | |
| val_dataloader = torch.utils.data.DataLoader( | |
| val_dataset, | |
| batch_size=self.config.train_batch_size, | |
| shuffle=False, | |
| ) | |
| total_metrics = {'mse': 0, 'mae': 0, 'correlation': 0} | |
| for batch in val_dataloader: | |
| vision_features = batch['visual_features'] | |
| step_embeddings = batch['step_descriptions'] | |
| target_rewards = batch['step_rewards'] | |
| metrics = prm_trainer.evaluate( | |
| vision_features, | |
| step_embeddings, | |
| target_rewards, | |
| ) | |
| for key in total_metrics: | |
| total_metrics[key] += metrics[key] | |
| # Average | |
| for key in total_metrics: | |
| total_metrics[key] /= len(val_dataloader) | |
| return total_metrics | |
    def _save_checkpoint(self, step: int, phase: str) -> None:
        """Save checkpoint for a specific training phase and upload to Hub.

        Args:
            step: Global training step the checkpoint corresponds to.
            phase: One of "sft" (full model + tokenizer, pushed to the Hub),
                "prm" (PRM state dict only), or "rl" (policy state dict only).
        """
        save_path = Path(self.config.output_dir) / f"checkpoint-{phase}-{step}"
        save_path.mkdir(parents=True, exist_ok=True)
        if phase == "sft":
            # Save the main model (unwrap from any Accelerate wrappers first)
            self.accelerator.unwrap_model(self.model).save_pretrained(save_path)
            if hasattr(self, 'tokenizer'):
                self.tokenizer.save_pretrained(save_path)
            # Upload to HuggingFace Hub for persistence across jobs; upload
            # failure is non-fatal because the local copy still exists.
            try:
                from huggingface_hub import HfApi
                api = HfApi()
                repo_id = getattr(self.config, 'hub_model_id', 'Mulebot/qwen-cot-checkpoints')
                logger.info(f"Uploading checkpoint to {repo_id}...")
                api.upload_folder(
                    folder_path=str(save_path),
                    repo_id=repo_id,
                    path_in_repo=f"checkpoint-{phase}-{step}",
                    repo_type="model",
                    commit_message=f"Save {phase} checkpoint at step {step}",
                )
                logger.info(f"Checkpoint uploaded to hub: {repo_id}/checkpoint-{phase}-{step}")
            except Exception as e:
                logger.warning(f"Failed to upload checkpoint to hub: {e}")
        elif phase == "prm":
            # Save PRM model (weights only; PRM checkpoints are not pushed to the Hub)
            torch.save(
                self.accelerator.unwrap_model(self.prm).state_dict(),
                save_path / "prm_model.pt",
            )
        elif phase == "rl":
            # Save RL policy (weights only)
            torch.save(
                self.accelerator.unwrap_model(self.rl_trainer.policy).state_dict(),
                save_path / "rl_policy.pt",
            )
        logger.info(f"Checkpoint saved to {save_path}")
    def _find_latest_checkpoint(self, phase: str) -> Optional[Path]:
        """Find the latest checkpoint for a specific phase from Hub or local.

        Checks the HuggingFace Hub first (downloading the newest matching
        checkpoint into output_dir when found), then falls back to scanning
        output_dir for local "checkpoint-{phase}-{step}" directories.

        Returns:
            Path to the newest checkpoint directory, or None when none exists.
        """
        # First, try to find checkpoints on HuggingFace Hub
        try:
            from huggingface_hub import HfApi, hf_hub_download
            api = HfApi()
            repo_id = getattr(self.config, 'hub_model_id', 'Mulebot/qwen-cot-checkpoints')
            logger.info(f"Checking for checkpoints on {repo_id}...")
            # List all files in the repo
            try:
                files = api.list_repo_files(repo_id, repo_type="model")
                # Find checkpoint directories (top-level path component of each file)
                checkpoint_dirs = set()
                for file in files:
                    if f"checkpoint-{phase}-" in file:
                        # Extract the checkpoint directory name
                        parts = file.split("/")
                        if len(parts) > 0:
                            checkpoint_dirs.add(parts[0])
                if checkpoint_dirs:
                    # Extract step numbers (directory names end in "-{step}")
                    checkpoint_steps = []
                    for ckpt_dir in checkpoint_dirs:
                        try:
                            step = int(ckpt_dir.split("-")[-1])
                            checkpoint_steps.append((step, ckpt_dir))
                        except ValueError:
                            # Not a step-numbered checkpoint directory; ignore
                            continue
                    if checkpoint_steps:
                        # Get the latest checkpoint
                        latest_step, latest_dir = max(checkpoint_steps, key=lambda x: x[0])
                        logger.info(f"Found checkpoint on hub at step {latest_step}: {latest_dir}")
                        # Download the checkpoint into output_dir/<checkpoint-dir>
                        local_dir = Path(self.config.output_dir) / latest_dir
                        local_dir.mkdir(parents=True, exist_ok=True)
                        logger.info(f"Downloading checkpoint from hub...")
                        from huggingface_hub import snapshot_download
                        snapshot_download(
                            repo_id=repo_id,
                            repo_type="model",
                            allow_patterns=f"{latest_dir}/*",
                            local_dir=self.config.output_dir,
                            local_dir_use_symlinks=False,
                        )
                        logger.info(f"Checkpoint downloaded to {local_dir}")
                        return local_dir
            except Exception as e:
                # Repo may not exist yet or be unreachable — fall through to local
                logger.info(f"No checkpoints found on hub or error accessing: {e}")
        except ImportError:
            logger.warning("huggingface_hub not available, skipping hub checkpoint check")
        # Fall back to local checkpoints
        checkpoint_dir = Path(self.config.output_dir)
        if not checkpoint_dir.exists():
            return None
        # Find all checkpoints for this phase
        checkpoints = list(checkpoint_dir.glob(f"checkpoint-{phase}-*"))
        if not checkpoints:
            return None
        # Extract step numbers and find the latest
        checkpoint_steps = []
        for ckpt in checkpoints:
            try:
                step = int(ckpt.name.split("-")[-1])
                checkpoint_steps.append((step, ckpt))
            except ValueError:
                continue
        if not checkpoint_steps:
            return None
        # Return checkpoint with highest step number
        latest_step, latest_path = max(checkpoint_steps, key=lambda x: x[0])
        logger.info(f"Found local checkpoint at step {latest_step}: {latest_path}")
        return latest_path
    def _load_checkpoint(self, checkpoint_path: Path, phase: str) -> int:
        """Load checkpoint and return the step number.

        Args:
            checkpoint_path: Checkpoint directory named "checkpoint-{phase}-{step}".
            phase: One of "sft", "prm", or "rl" (selects which component to restore).

        Returns:
            The step number parsed from the checkpoint directory name.
        """
        if phase == "sft":
            # Load model from checkpoint
            from transformers import AutoModelForVision2Seq
            logger.info(f"Loading SFT checkpoint from {checkpoint_path}")
            # Load the checkpoint model, then copy its weights into the live
            # model instead of replacing it (keeps existing references —
            # rl_trainer, inference_scaler — pointing at the same object).
            checkpoint_model = AutoModelForVision2Seq.from_pretrained(
                checkpoint_path,
                torch_dtype=torch.bfloat16,
                device_map=self.accelerator.device,
            )
            # Copy weights to current model
            self.model.load_state_dict(checkpoint_model.state_dict())
            del checkpoint_model  # free the temporary copy immediately
        elif phase == "prm":
            logger.info(f"Loading PRM checkpoint from {checkpoint_path}")
            self.prm.load_state_dict(
                torch.load(checkpoint_path / "prm_model.pt", map_location=self.accelerator.device)
            )
        elif phase == "rl":
            logger.info(f"Loading RL checkpoint from {checkpoint_path}")
            self.rl_trainer.policy.load_state_dict(
                torch.load(checkpoint_path / "rl_policy.pt", map_location=self.accelerator.device)
            )
        # Extract step number from checkpoint path (directory name ends in "-{step}")
        step = int(checkpoint_path.name.split("-")[-1])
        logger.info(f"Resumed from checkpoint at step {step}")
        return step
| def _save_prm(self, epoch: int) -> None: | |
| """Save PRM checkpoint.""" | |
| save_path = Path(self.config.output_dir) / "prm" / f"epoch_{epoch}" | |
| save_path.mkdir(parents=True, exist_ok=True) | |
| torch.save( | |
| self.prm.state_dict(), | |
| save_path / "prm_model.pt", | |
| ) | |
| logger.info(f"PRM saved to {save_path}") | |
| def save_checkpoint(self, step: int) -> None: | |
| """Save complete training checkpoint.""" | |
| save_path = Path(self.config.output_dir) / f"checkpoint-{step}" | |
| save_path.mkdir(parents=True, exist_ok=True) | |
| # Save policy | |
| self.accelerator.save_state(str(save_path / "policy")) | |
| # Save PRM | |
| torch.save(self.prm.state_dict(), save_path / "prm.pt") | |
| # Save RL trainer state | |
| self.rl_trainer.save_checkpoint(str(save_path / "rl_trainer.pt")) | |
| logger.info(f"Checkpoint saved to {save_path}") | |
def main():
    """Main training entry point.

    Parses CLI arguments, builds the trainer, and dispatches the requested
    phase(s). BUG FIX: the previous version passed StepDatasets positionally to
    train_prm/train_rl, silently binding them to the max_steps/learning_rate
    parameters; those methods load their own streaming datasets, so they are
    now called without dataset arguments.
    """
    parser = argparse.ArgumentParser(description="Train step-level CoT reasoning for vision models")
    parser.add_argument("--config", type=str, required=True, help="Path to config YAML")
    # "sft" added so the documented phase 1 is reachable from the CLI;
    # "all" keeps its original meaning (prm + rl + inference).
    parser.add_argument(
        "--phase",
        type=str,
        choices=["sft", "prm", "rl", "inference", "all"],
        default="all",
    )
    # Only the inference phase consumes the step-level JSON data, so the flag
    # is no longer unconditionally required.
    parser.add_argument(
        "--data_dir",
        type=str,
        default=None,
        help="Path to reasoning chain data (required for the inference phase)",
    )
    args = parser.parse_args()
    # Initialize trainer
    trainer = StepLevelCoTTrainer(
        config_path=args.config,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    # Run training phases; each training method streams its own dataset.
    if args.phase == "sft":
        trainer.train_sft()
    if args.phase in ["prm", "all"]:
        trainer.train_prm()
    if args.phase in ["rl", "all"]:
        trainer.train_rl()
    if args.phase in ["inference", "all"]:
        if args.data_dir is None:
            parser.error("--data_dir is required for the inference phase")
        test_dataset = trainer.load_step_dataset(args.data_dir, split="test")
        trainer.evaluate_inference_scaling(test_dataset)

if __name__ == "__main__":
    main()