""" Process Reward Model (PRM) for evaluating step-level reasoning quality. The PRM assigns rewards to intermediate reasoning steps, enabling fine-grained reinforcement learning and quality assessment. """ from dataclasses import dataclass from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from transformers import PreTrainedModel import logging logger = logging.getLogger(__name__) @dataclass class StepQualityMetrics: """Metrics for assessing reasoning step quality.""" coherence: float # How well step follows from previous steps [0, 1] relevance: float # Relevance to prompt and image [0, 1] correctness: float # Logical correctness [0, 1] informativeness: float # How much new information added [0, 1] confidence: float # Model's confidence in this step [0, 1] def to_reward(self, weights: Optional[Dict[str, float]] = None) -> float: """ Convert metrics to single reward score. Args: weights: Optional custom weights for each metric Returns: Reward in [-1, 1] range """ if weights is None: weights = { 'coherence': 0.25, 'relevance': 0.25, 'correctness': 0.3, 'informativeness': 0.15, 'confidence': 0.05, } reward = sum( getattr(self, key) * weight for key, weight in weights.items() ) # Normalize to [-1, 1] (assuming metrics are [0, 1]) return reward * 2.0 - 1.0 class ProcessRewardModel(nn.Module): """ Process Reward Model for step-level reasoning evaluation. Architecture: - Vision encoder: Extract visual features - Text encoder: Encode reasoning steps - Context aggregator: Combine previous steps - Reward head: Predict step quality metrics """ def __init__( self, vision_dim: int = 768, text_dim: int = 768, hidden_dim: int = 512, num_heads: int = 8, dropout: float = 0.1, max_steps: int = 10, ): """ Initialize PRM. 
        """
        if weights is None:
            weights = {
                'coherence': 0.25,
                'relevance': 0.25,
                'correctness': 0.3,
                'informativeness': 0.15,
                'confidence': 0.05,
            }

        reward = sum(
            getattr(self, key) * weight
            for key, weight in weights.items()
        )

        # Map the [0, 1] weighted mean to [-1, 1].
        return reward * 2.0 - 1.0


class ProcessRewardModel(nn.Module):
    """
    Process Reward Model for step-level reasoning evaluation.

    Architecture:
        - Vision projection: project visual features
        - Text projection: project reasoning-step embeddings
        - Context aggregator: attend over previous steps
        - Reward heads: predict step quality metrics and an overall reward
    """

    def __init__(
        self,
        vision_dim: int = 768,
        text_dim: int = 768,
        hidden_dim: int = 512,
        num_heads: int = 8,
        dropout: float = 0.1,
        max_steps: int = 10,
    ):
        """
        Initialize the PRM.

        Args:
            vision_dim: Dimension of vision features
            text_dim: Dimension of text embeddings
            hidden_dim: Hidden dimension for processing
            num_heads: Number of attention heads
            dropout: Dropout rate
            max_steps: Maximum reasoning steps to handle
        """
        super().__init__()
        self.vision_dim = vision_dim
        self.text_dim = text_dim
        self.hidden_dim = hidden_dim
        self.max_steps = max_steps

        # Vision and text projections
        self.vision_proj = nn.Linear(vision_dim, hidden_dim)
        self.text_proj = nn.Linear(text_dim, hidden_dim)

        # Context aggregator with self-attention over previous steps
        self.context_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # Step encoder
        self.step_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=num_heads,
                dim_feedforward=hidden_dim * 4,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=2,
        )

        def quality_head() -> nn.Sequential:
            # All per-metric heads share the same small MLP with a
            # sigmoid output in [0, 1].
            return nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim // 2, 1),
                nn.Sigmoid(),
            )

        # Reward heads for the individual quality metrics
        self.coherence_head = quality_head()
        self.relevance_head = quality_head()
        self.correctness_head = quality_head()
        self.informativeness_head = quality_head()

        # Confidence estimation
        self.confidence_head = quality_head()

        # Overall reward head (learned combination)
        self.reward_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1),
            nn.Tanh(),  # Output in [-1, 1]
        )

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        vision_features: torch.Tensor,
        step_embeddings: torch.Tensor,
        step_mask: Optional[torch.Tensor] = None,
        return_metrics: bool = False,
    ) -> Tuple[torch.Tensor, Optional[List[List[StepQualityMetrics]]]]:
        """
        Forward pass to compute rewards for reasoning steps.

        Args:
            vision_features: Visual features [batch, vision_dim]
            step_embeddings: Text embeddings for each step [batch, num_steps, text_dim]
            step_mask: Boolean mask for valid steps (True = valid) [batch, num_steps]
            return_metrics: If True, also return detailed metrics

        Returns:
            - rewards: Reward for each step [batch, num_steps]
            - metrics: Optional per-sample lists of StepQualityMetrics
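
        Example (shapes are illustrative; dimensions must match the
        constructor arguments):

            >>> prm = ProcessRewardModel(vision_dim=768, text_dim=768)
            >>> vision = torch.randn(2, 768)    # batch of 2 images
            >>> steps = torch.randn(2, 5, 768)  # 5 reasoning steps each
            >>> rewards, _ = prm(vision, steps)
            >>> rewards.shape
            torch.Size([2, 5])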
        """
        batch_size, num_steps, _ = step_embeddings.shape

        # Project vision and text features
        vision_feat = self.vision_proj(vision_features)  # [batch, hidden_dim]
        text_feat = self.text_proj(step_embeddings)      # [batch, num_steps, hidden_dim]

        # Expand vision features to match the step dimension
        vision_expanded = vision_feat.unsqueeze(1).expand(-1, num_steps, -1)

        # Combine vision and text
        combined = vision_expanded + text_feat  # [batch, num_steps, hidden_dim]
        combined = self.dropout(combined)

        # Causal mask so each step attends only to itself and previous
        # steps; reasoning is autoregressive, so this applies whether or
        # not a padding mask is given.
        causal_mask = torch.triu(
            torch.ones(num_steps, num_steps, device=combined.device),
            diagonal=1,
        ).bool()

        attended, _ = self.context_attention(
            combined, combined, combined,
            attn_mask=causal_mask,
            key_padding_mask=~step_mask if step_mask is not None else None,
        )

        # Encode steps
        encoded = self.step_encoder(attended)  # [batch, num_steps, hidden_dim]

        # Compute rewards
        rewards = self.reward_head(encoded).squeeze(-1)  # [batch, num_steps]

        if step_mask is not None:
            rewards = rewards.masked_fill(~step_mask, 0.0)

        # Optionally compute detailed metrics
        metrics = None
        if return_metrics:
            metrics = []
            for b in range(batch_size):
                step_metrics = []
                for s in range(num_steps):
                    if step_mask is None or step_mask[b, s]:
                        feat = encoded[b, s:s + 1]
                        quality = StepQualityMetrics(
                            coherence=self.coherence_head(feat).item(),
                            relevance=self.relevance_head(feat).item(),
                            correctness=self.correctness_head(feat).item(),
                            informativeness=self.informativeness_head(feat).item(),
                            confidence=self.confidence_head(feat).item(),
                        )
                        step_metrics.append(quality)
                metrics.append(step_metrics)

        return rewards, metrics

    def compute_step_reward(
        self,
        vision_features: torch.Tensor,
        current_step: torch.Tensor,
        previous_steps: Optional[torch.Tensor] = None,
        return_metrics: bool = False,
    ) -> Tuple[float, Optional[StepQualityMetrics]]:
        """
        Compute the reward for a single reasoning step.

        Args:
            vision_features: Visual features [vision_dim]
            current_step: Current step embedding [text_dim]
            previous_steps: Previous step embeddings [num_prev_steps, text_dim]
            return_metrics: If True, also return detailed metrics

        Returns:
            - reward: Scalar reward in [-1, 1]
            - metrics: Optional StepQualityMetrics for the current step
        """
        # Prepare batch inputs
        vision_features = vision_features.unsqueeze(0)  # [1, vision_dim]

        if previous_steps is not None:
            # Concatenate previous and current steps
            steps = torch.cat([previous_steps, current_step.unsqueeze(0)], dim=0)
        else:
            steps = current_step.unsqueeze(0)

        steps = steps.unsqueeze(0)  # [1, num_steps, text_dim]

        # Forward pass (the caller is expected to have put the model in
        # eval mode; dropout is otherwise still active)
        with torch.no_grad():
            rewards, metrics = self.forward(
                vision_features,
                steps,
                return_metrics=return_metrics,
            )

        # Extract the reward of the last (current) step
        reward = rewards[0, -1].item()
        step_metrics = metrics[0][-1] if metrics else None

        return reward, step_metrics

    @torch.no_grad()
    def evaluate_chain(
        self,
        vision_features: torch.Tensor,
        step_embeddings: List[torch.Tensor],
    ) -> Tuple[List[float], float]:
        """
        Evaluate a complete reasoning chain.

        Args:
            vision_features: Visual features [vision_dim]
            step_embeddings: List of step embeddings, each [text_dim]

        Returns:
            - step_rewards: Reward for each step
            - total_reward: Sum of all step rewards
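
        Example (illustrative; real embeddings would come from a text
        encoder, not torch.randn):

            >>> prm = ProcessRewardModel()
            >>> chain = [torch.randn(768) for _ in range(3)]
            >>> per_step, total = prm.evaluate_chain(torch.randn(768), chain)
            >>> len(per_step)
            3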
        """
        if not step_embeddings:
            return [], 0.0

        # Stack steps
        steps = torch.stack(step_embeddings).unsqueeze(0)  # [1, num_steps, text_dim]
        vision_features = vision_features.unsqueeze(0)     # [1, vision_dim]

        # Compute rewards
        rewards, _ = self.forward(vision_features, steps)

        step_rewards = rewards[0].tolist()
        total_reward = sum(step_rewards)

        return step_rewards, total_reward


class PRMTrainer:
    """Trainer for the Process Reward Model using step-level supervision."""

    def __init__(
        self,
        model: ProcessRewardModel,
        learning_rate: float = 1e-4,
        weight_decay: float = 0.01,
        warmup_steps: int = 1000,
    ):
        """
        Initialize the PRM trainer.

        Args:
            model: Process Reward Model
            learning_rate: Base learning rate (reached after warmup)
            weight_decay: Weight decay for regularization
            warmup_steps: Number of linear warmup steps
        """
        self.model = model
        self.learning_rate = learning_rate
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
        )
        self.warmup_steps = warmup_steps
        self.current_step = 0

    def train_step(
        self,
        vision_features: torch.Tensor,
        step_embeddings: torch.Tensor,
        target_rewards: torch.Tensor,
        step_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, float]:
        """
        Run a single training step.

        Args:
            vision_features: Visual features [batch, vision_dim]
            step_embeddings: Step embeddings [batch, num_steps, text_dim]
            target_rewards: Ground-truth rewards [batch, num_steps]
            step_mask: Valid-step mask [batch, num_steps]

        Returns:
            Dictionary of losses
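
        Example (illustrative shapes; in practice the targets come from
        step-level annotations rather than random values):

            >>> trainer = PRMTrainer(ProcessRewardModel())
            >>> vision = torch.randn(4, 768)
            >>> steps = torch.randn(4, 6, 768)
            >>> targets = torch.empty(4, 6).uniform_(-1, 1)
            >>> losses = trainer.train_step(vision, steps, targets)
            >>> sorted(losses)
            ['avg_predicted_reward', 'avg_target_reward', 'loss']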
        """
        self.model.train()
        self.optimizer.zero_grad()

        # Forward pass
        predicted_rewards, _ = self.model(
            vision_features,
            step_embeddings,
            step_mask,
        )

        # Compute loss
        if step_mask is not None:
            # Masked MSE loss over valid steps only
            loss = F.mse_loss(
                predicted_rewards.masked_select(step_mask),
                target_rewards.masked_select(step_mask),
            )
        else:
            loss = F.mse_loss(predicted_rewards, target_rewards)

        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

        # Linear learning-rate warmup, applied before the optimizer step
        # so the earliest updates are actually scaled down
        if self.current_step < self.warmup_steps:
            lr_scale = min(1.0, (self.current_step + 1) / self.warmup_steps)
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.learning_rate * lr_scale

        self.optimizer.step()
        self.current_step += 1

        return {
            'loss': loss.item(),
            'avg_predicted_reward': predicted_rewards.mean().item(),
            'avg_target_reward': target_rewards.mean().item(),
        }

    @torch.no_grad()
    def evaluate(
        self,
        vision_features: torch.Tensor,
        step_embeddings: torch.Tensor,
        target_rewards: torch.Tensor,
        step_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, float]:
        """
        Run an evaluation step. Takes the same arguments as `train_step`.

        Returns:
            Dictionary of evaluation metrics
        """
        self.model.eval()

        # Forward pass
        predicted_rewards, _ = self.model(
            vision_features,
            step_embeddings,
            step_mask,
        )

        # Select valid predictions and targets
        if step_mask is not None:
            valid_pred = predicted_rewards.masked_select(step_mask)
            valid_target = target_rewards.masked_select(step_mask)
        else:
            valid_pred = predicted_rewards.flatten()
            valid_target = target_rewards.flatten()

        mse = F.mse_loss(valid_pred, valid_target).item()
        mae = F.l1_loss(valid_pred, valid_target).item()

        # Pearson correlation; population statistics (unbiased=False) keep
        # the covariance and standard deviations consistent
        pred_mean = valid_pred.mean()
        target_mean = valid_target.mean()
        covariance = ((valid_pred - pred_mean) * (valid_target - target_mean)).mean()
        pred_std = valid_pred.std(unbiased=False)
        target_std = valid_target.std(unbiased=False)
        correlation = covariance / (pred_std * target_std + 1e-8)

        return {
            'mse': mse,
            'mae': mae,
            'correlation': correlation.item(),
            'avg_predicted_reward': valid_pred.mean().item(),
            'avg_target_reward': valid_target.mean().item(),
        }
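

if __name__ == "__main__":
    # Minimal smoke test: a sketch with assumed 768-dim features; real
    # vision/text features would come from frozen pretrained encoders.
    torch.manual_seed(0)
    prm = ProcessRewardModel()
    trainer = PRMTrainer(prm)
    vision = torch.randn(2, 768)
    steps = torch.randn(2, 4, 768)
    targets = torch.empty(2, 4).uniform_(-1, 1)
    print(trainer.train_step(vision, steps, targets))
    print(trainer.evaluate(vision, steps, targets))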