""" Reinforcement Learning trainer for vision-language models with step-level rewards. Implements PPO (Proximal Policy Optimization) with fine-grained process rewards from the PRM for each reasoning step. """ from typing import Dict, List, Optional, Tuple, Any import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader from dataclasses import dataclass import logging from .prm import ProcessRewardModel from .step_data import ReasoningChain, StepDataset logger = logging.getLogger(__name__) @dataclass class RLConfig: """Configuration for RL training.""" # PPO hyperparameters clip_epsilon: float = 0.2 value_loss_coef: float = 0.5 entropy_coef: float = 0.01 max_grad_norm: float = 1.0 gamma: float = 0.99 # Discount factor gae_lambda: float = 0.95 # GAE parameter # Training settings num_epochs: int = 4 batch_size: int = 8 learning_rate: float = 1e-5 warmup_steps: int = 500 # Step-level rewards use_prm_rewards: bool = True normalize_rewards: bool = True reward_scale: float = 1.0 # Exploration temperature: float = 1.0 top_k: int = 50 top_p: float = 0.95 class RLReasoningTrainer: """ Reinforcement Learning trainer for step-level reasoning. Uses PPO to optimize the vision-language model with fine-grained rewards from the Process Reward Model (PRM) at each reasoning step. """ def __init__( self, policy_model: nn.Module, prm_model: ProcessRewardModel, config: RLConfig, device: str = "cuda", ): """ Initialize RL trainer. Args: policy_model: Vision-language model to train prm_model: Process Reward Model for step evaluation config: RL training configuration device: Device for training """ self.policy = policy_model self.prm = prm_model self.config = config self.device = device # Move models to device (skip if already quantized with device_map) if not (hasattr(self.policy, 'hf_device_map') or getattr(self.policy, 'is_quantized', False)): self.policy.to(device) if not (hasattr(self.prm, 'hf_device_map') or getattr(self.prm, 'is_quantized', False)): self.prm.to(device) # Freeze PRM (only train policy) for param in self.prm.parameters(): param.requires_grad = False self.prm.eval() # Optimizer self.optimizer = torch.optim.AdamW( self.policy.parameters(), lr=config.learning_rate, betas=(0.9, 0.999), eps=1e-8, ) # Value head for PPO (estimates state value) self.value_head = nn.Sequential( nn.Linear(768, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, 1), ).to(device) self.value_optimizer = torch.optim.AdamW( self.value_head.parameters(), lr=config.learning_rate * 3, ) self.global_step = 0 def generate_reasoning_chain( self, image_features: torch.Tensor, prompt_ids: torch.Tensor, prompt_mask: torch.Tensor, max_steps: int = 10, ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[float]]: """ Generate a reasoning chain using the current policy. 
    def generate_reasoning_chain(
        self,
        image_features: torch.Tensor,
        prompt_ids: torch.Tensor,
        prompt_mask: torch.Tensor,
        max_steps: int = 10,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[float]]:
        """
        Generate a reasoning chain using the current policy.

        Args:
            image_features: Visual features [vision_dim]
            prompt_ids: Tokenized prompt [seq_len]
            prompt_mask: Prompt attention mask [seq_len]
            max_steps: Maximum reasoning steps

        Returns:
            - step_embeddings: List of step embeddings
            - step_logprobs: List of log probabilities for each step
            - step_entropies: List of entropies for each step
        """
        self.policy.eval()

        step_embeddings = []
        step_logprobs = []
        step_entropies = []

        # Initialize with prompt
        current_ids = prompt_ids.unsqueeze(0).to(self.device)
        current_mask = prompt_mask.unsqueeze(0).to(self.device)

        for step_idx in range(max_steps):
            with torch.no_grad():
                # Forward pass through policy
                outputs = self.policy(
                    input_ids=current_ids,
                    attention_mask=current_mask,
                    output_hidden_states=True,
                )

                logits = outputs.logits[:, -1, :]  # Next token logits
                hidden_state = outputs.hidden_states[-1][:, -1, :]  # Last hidden state

                # Apply temperature and convert to probabilities for sampling
                logits = logits / self.config.temperature
                probs = F.softmax(logits, dim=-1)

                # Top-k filtering
                if self.config.top_k > 0:
                    top_k_probs, top_k_indices = torch.topk(probs, self.config.top_k, dim=-1)
                    probs = torch.zeros_like(probs).scatter_(1, top_k_indices, top_k_probs)
                    probs = probs / probs.sum(dim=-1, keepdim=True)

                # Top-p (nucleus) filtering
                if self.config.top_p < 1.0:
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

                    # Remove tokens with cumulative probability above threshold
                    sorted_indices_to_remove = cumulative_probs > self.config.top_p
                    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                    sorted_indices_to_remove[:, 0] = False

                    indices_to_remove = sorted_indices_to_remove.scatter(
                        1, sorted_indices, sorted_indices_to_remove
                    )
                    probs = probs.masked_fill(indices_to_remove, 0.0)
                    probs = probs / probs.sum(dim=-1, keepdim=True)

                # Sample next token
                next_token = torch.multinomial(probs, num_samples=1)

                # Calculate log probability and entropy of the sampled distribution
                log_prob = torch.log(probs.gather(1, next_token) + 1e-10)
                entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=-1)

                # Store step information
                step_embeddings.append(hidden_state.squeeze(0))
                step_logprobs.append(log_prob.squeeze())
                step_entropies.append(entropy.item())

                # Check for stop token or step-completion marker (EOS or newline, id 198)
                if next_token.item() in [self.policy.config.eos_token_id, 198]:
                    break

                # Update sequence with the sampled token
                current_ids = torch.cat([current_ids, next_token], dim=1)
                current_mask = torch.cat([
                    current_mask,
                    torch.ones((1, 1), device=self.device),
                ], dim=1)

        return step_embeddings, step_logprobs, step_entropies
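    # `compute_advantages` below implements the standard GAE recursion, processed
    # backwards over the chain of reasoning steps:
    #
    #     delta_t = r_t + gamma * V_{t+1} * (1 - done_t) - V_t
    #     A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
    #     R_t     = A_t + V_t        # return used as the value-head target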
    def compute_advantages(
        self,
        rewards: torch.Tensor,
        values: torch.Tensor,
        dones: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute Generalized Advantage Estimation (GAE).

        Args:
            rewards: Step rewards [num_steps]
            values: Value estimates [num_steps]
            dones: Done flags [num_steps]

        Returns:
            - advantages: Advantage estimates [num_steps]
            - returns: Discounted returns [num_steps]
        """
        num_steps = len(rewards)
        advantages = torch.zeros_like(rewards)
        returns = torch.zeros_like(rewards)

        gae = 0
        next_value = 0

        for t in reversed(range(num_steps)):
            if t == num_steps - 1:
                next_non_terminal = 1.0 - dones[t]
                next_value = 0  # No bootstrap value beyond the final step
            else:
                next_non_terminal = 1.0 - dones[t]
                next_value = values[t + 1]

            delta = rewards[t] + self.config.gamma * next_value * next_non_terminal - values[t]
            gae = delta + self.config.gamma * self.config.gae_lambda * next_non_terminal * gae

            advantages[t] = gae
            returns[t] = gae + values[t]

        return advantages, returns

    def ppo_update(
        self,
        step_embeddings: List[torch.Tensor],
        old_logprobs: torch.Tensor,
        rewards: torch.Tensor,
        advantages: torch.Tensor,
        returns: torch.Tensor,
    ) -> Dict[str, float]:
        """
        Perform PPO update.

        Args:
            step_embeddings: Hidden states for each step
            old_logprobs: Old policy log probabilities
            rewards: Step rewards
            advantages: Advantage estimates
            returns: Discounted returns

        Returns:
            Dictionary of training metrics
        """
        self.policy.train()
        self.value_head.train()

        # Stack embeddings: [num_steps, hidden_dim]
        embeddings = torch.stack(step_embeddings)

        # Normalize advantages
        if self.config.normalize_rewards:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        total_policy_loss = 0
        total_value_loss = 0
        total_entropy_loss = 0

        for epoch in range(self.config.num_epochs):
            # Compute values from the stored step embeddings
            values = self.value_head(embeddings).squeeze(-1)

            # Value loss
            value_loss = F.mse_loss(values, returns)

            # For the policy loss we would need to recompute log probabilities with a
            # fresh forward pass through the policy. This is simplified here: reusing
            # old_logprobs makes the ratio identically 1.
            ratio = torch.exp(old_logprobs - old_logprobs)  # Would be new / old

            # PPO clipped objective
            policy_loss_1 = -advantages * ratio
            policy_loss_2 = -advantages * torch.clamp(
                ratio, 1 - self.config.clip_epsilon, 1 + self.config.clip_epsilon
            )
            policy_loss = torch.max(policy_loss_1, policy_loss_2).mean()

            # Entropy bonus for exploration (approximated from the stored log-probs)
            entropy_loss = -self.config.entropy_coef * old_logprobs.mean()

            # Total loss
            loss = (
                policy_loss
                + self.config.value_loss_coef * value_loss
                + entropy_loss
            )

            # Optimization step
            self.optimizer.zero_grad()
            self.value_optimizer.zero_grad()
            loss.backward()

            torch.nn.utils.clip_grad_norm_(
                self.policy.parameters(), self.config.max_grad_norm
            )
            torch.nn.utils.clip_grad_norm_(
                self.value_head.parameters(), self.config.max_grad_norm
            )

            self.optimizer.step()
            self.value_optimizer.step()

            total_policy_loss += policy_loss.item()
            total_value_loss += value_loss.item()
            total_entropy_loss += entropy_loss.item()

        self.global_step += 1

        return {
            'policy_loss': total_policy_loss / self.config.num_epochs,
            'value_loss': total_value_loss / self.config.num_epochs,
            'entropy_loss': total_entropy_loss / self.config.num_epochs,
            'avg_reward': rewards.mean().item(),
            'avg_advantage': advantages.mean().item(),
            'avg_return': returns.mean().item(),
        }
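    # Note: because the ratio in `ppo_update` above is a placeholder (new log-probs
    # are never recomputed, so the ratio is identically 1 and no gradient reaches
    # the policy), only the value head is actually updated. A minimal sketch of a
    # full recomputation, assuming the sampled token ids and attention mask for the
    # whole chain were kept (hypothetical names `chain_ids` / `chain_mask`):
    #
    #     new_logits = self.policy(input_ids=chain_ids, attention_mask=chain_mask).logits
    #     new_logprobs = F.log_softmax(new_logits[:, :-1], dim=-1).gather(
    #         -1, chain_ids[:, 1:].unsqueeze(-1)
    #     ).squeeze(-1)
    #     ratio = torch.exp(new_logprobs - old_logprobs)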
    def train_step(
        self,
        batch: Dict[str, torch.Tensor],
    ) -> Dict[str, float]:
        """
        Single RL training step.

        Args:
            batch: Batch from StepDataset

        Returns:
            Training metrics
        """
        # Extract batch data
        image_features = batch['visual_features'].to(self.device)  # [batch, vision_dim]
        prompt_ids = batch['prompt_ids'].to(self.device)
        prompt_mask = batch['prompt_mask'].to(self.device)

        batch_size = image_features.size(0)
        all_metrics = []

        for b in range(batch_size):
            # Generate reasoning chain with current policy
            step_embeddings, step_logprobs, step_entropies = self.generate_reasoning_chain(
                image_features[b],
                prompt_ids[b],
                prompt_mask[b],
                max_steps=10,
            )

            if not step_embeddings:
                continue

            # Get PRM rewards for each step
            step_rewards = []
            for step_emb in step_embeddings:
                reward, _ = self.prm.compute_step_reward(
                    image_features[b],
                    step_emb,
                    previous_steps=(
                        torch.stack(step_embeddings[:len(step_rewards)])
                        if step_rewards else None
                    ),
                )
                step_rewards.append(reward * self.config.reward_scale)

            rewards = torch.tensor(step_rewards, device=self.device)
            logprobs = torch.stack(step_logprobs)

            # Compute value estimates
            embeddings = torch.stack(step_embeddings)
            with torch.no_grad():
                values = self.value_head(embeddings).squeeze(-1)

            # Done flags (only the last step is terminal)
            dones = torch.zeros(len(step_embeddings), device=self.device)
            dones[-1] = 1.0

            # Compute advantages and returns
            advantages, returns = self.compute_advantages(rewards, values, dones)

            # PPO update
            metrics = self.ppo_update(
                step_embeddings,
                logprobs,
                rewards,
                advantages,
                returns,
            )
            all_metrics.append(metrics)

        # Average metrics across batch
        if not all_metrics:
            return {}

        avg_metrics = {
            key: sum(m[key] for m in all_metrics) / len(all_metrics)
            for key in all_metrics[0].keys()
        }
        return avg_metrics

    def train(
        self,
        train_dataset: StepDataset,
        num_iterations: int = 1000,
        log_interval: int = 10,
    ) -> None:
        """
        Train the policy with RL.

        Args:
            train_dataset: Dataset with reasoning chains
            num_iterations: Number of training iterations
            log_interval: Logging frequency
        """
        dataloader = DataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
        )

        logger.info(f"Starting RL training for {num_iterations} iterations")

        for iteration in range(num_iterations):
            for batch in dataloader:
                metrics = self.train_step(batch)

                if metrics and iteration % log_interval == 0:
                    logger.info(
                        f"Iteration {iteration}: "
                        f"Policy Loss: {metrics['policy_loss']:.4f}, "
                        f"Value Loss: {metrics['value_loss']:.4f}, "
                        f"Avg Reward: {metrics['avg_reward']:.4f}"
                    )

        logger.info("RL training completed")

    def save_checkpoint(self, path: str) -> None:
        """Save training checkpoint."""
        checkpoint = {
            'policy_state_dict': self.policy.state_dict(),
            'value_head_state_dict': self.value_head.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'value_optimizer_state_dict': self.value_optimizer.state_dict(),
            'global_step': self.global_step,
            'config': self.config,
        }
        torch.save(checkpoint, path)
        logger.info(f"Checkpoint saved to {path}")

    def load_checkpoint(self, path: str) -> None:
        """Load training checkpoint."""
        checkpoint = torch.load(path, map_location=self.device)
        self.policy.load_state_dict(checkpoint['policy_state_dict'])
        self.value_head.load_state_dict(checkpoint['value_head_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.value_optimizer.load_state_dict(checkpoint['value_optimizer_state_dict'])
        self.global_step = checkpoint['global_step']
        logger.info(f"Checkpoint loaded from {path}")
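
# Example usage (illustrative sketch only; the model loading call and dataset
# construction below are placeholders, not this project's actual API):
#
#     policy = load_vision_language_model("path/to/policy")   # hypothetical loader
#     prm = ProcessRewardModel(...)                            # trained, frozen PRM
#     dataset = StepDataset(...)                               # reasoning-chain data
#
#     trainer = RLReasoningTrainer(policy, prm, RLConfig(batch_size=4))
#     trainer.train(train_dataset=dataset, num_iterations=500, log_interval=10)
#     trainer.save_checkpoint("checkpoints/rl_policy.pt")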