| | """ |
| | Process Reward Model (PRM) for evaluating step-level reasoning quality. |
| | |
| | The PRM assigns rewards to intermediate reasoning steps, enabling fine-grained |
| | reinforcement learning and quality assessment. |
| | """ |
| |
|
| | from dataclasses import dataclass |
| | from typing import Dict, List, Optional, Tuple |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from transformers import PreTrainedModel |
| | import logging |
| |
|
| | logger = logging.getLogger(__name__) |


@dataclass
class StepQualityMetrics:
    """Metrics for assessing reasoning step quality."""

    coherence: float
    relevance: float
    correctness: float
    informativeness: float
    confidence: float

    def to_reward(self, weights: Optional[Dict[str, float]] = None) -> float:
        """
        Convert metrics to a single reward score.

        Args:
            weights: Optional custom weights for each metric

        Returns:
            Reward in [-1, 1] range
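
        Example (illustrative values; assumes the default weights below):
            >>> m = StepQualityMetrics(
            ...     coherence=0.9, relevance=0.8, correctness=0.7,
            ...     informativeness=0.6, confidence=0.5,
            ... )
            >>> round(m.to_reward(), 3)
            0.5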
| | """ |
| | if weights is None: |
| | weights = { |
| | 'coherence': 0.25, |
| | 'relevance': 0.25, |
| | 'correctness': 0.3, |
| | 'informativeness': 0.15, |
| | 'confidence': 0.05, |
| | } |
| | |
| | reward = sum( |
| | getattr(self, key) * weight |
| | for key, weight in weights.items() |
| | ) |
| | |
| | |
| | return reward * 2.0 - 1.0 |


class ProcessRewardModel(nn.Module):
    """
    Process Reward Model for step-level reasoning evaluation.

    Architecture:
        - Vision/text projections: map precomputed features to a shared space
        - Context attention: causally aggregate previous steps
        - Step encoder: refine per-step representations
        - Reward and metric heads: predict step quality
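
    Example (shape sketch with random tensors, using the default dims):
        >>> prm = ProcessRewardModel()
        >>> vision = torch.randn(2, 768)
        >>> steps = torch.randn(2, 5, 768)
        >>> rewards, _ = prm(vision, steps)
        >>> rewards.shape
        torch.Size([2, 5])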
| | """ |
| | |
| | def __init__( |
| | self, |
| | vision_dim: int = 768, |
| | text_dim: int = 768, |
| | hidden_dim: int = 512, |
| | num_heads: int = 8, |
| | dropout: float = 0.1, |
| | max_steps: int = 10, |
| | ): |
| | """ |
| | Initialize PRM. |
| | |
| | Args: |
| | vision_dim: Dimension of vision features |
| | text_dim: Dimension of text embeddings |
| | hidden_dim: Hidden dimension for processing |
| | num_heads: Number of attention heads |
| | dropout: Dropout rate |
| | max_steps: Maximum reasoning steps to handle |
| | """ |
| | super().__init__() |
| | |
| | self.vision_dim = vision_dim |
| | self.text_dim = text_dim |
| | self.hidden_dim = hidden_dim |
| | self.max_steps = max_steps |
| | |
| | |
| | self.vision_proj = nn.Linear(vision_dim, hidden_dim) |
| | self.text_proj = nn.Linear(text_dim, hidden_dim) |
| | |
| | |
| | self.context_attention = nn.MultiheadAttention( |
| | embed_dim=hidden_dim, |
| | num_heads=num_heads, |
| | dropout=dropout, |
| | batch_first=True, |
| | ) |
| | |
| | |
| | self.step_encoder = nn.TransformerEncoder( |
| | nn.TransformerEncoderLayer( |
| | d_model=hidden_dim, |
| | nhead=num_heads, |
| | dim_feedforward=hidden_dim * 4, |
| | dropout=dropout, |
| | batch_first=True, |
| | ), |
| | num_layers=2, |
| | ) |
| | |
| | |
        def make_metric_head() -> nn.Sequential:
            return nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim // 2, 1),
                nn.Sigmoid(),
            )

        self.coherence_head = make_metric_head()
        self.relevance_head = make_metric_head()
        self.correctness_head = make_metric_head()
        self.informativeness_head = make_metric_head()
        self.confidence_head = make_metric_head()

        # Scalar reward head, bounded to [-1, 1] by tanh.
        self.reward_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1),
            nn.Tanh(),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        vision_features: torch.Tensor,
        step_embeddings: torch.Tensor,
        step_mask: Optional[torch.Tensor] = None,
        return_metrics: bool = False,
    ) -> Tuple[torch.Tensor, Optional[List[StepQualityMetrics]]]:
        """
        Forward pass to compute rewards for reasoning steps.

        Args:
            vision_features: Visual features [batch, vision_dim]
            step_embeddings: Text embeddings for each step [batch, num_steps, text_dim]
            step_mask: Boolean mask of valid steps (True = valid) [batch, num_steps]
            return_metrics: If True, return detailed metrics

        Returns:
            - rewards: Reward for each step [batch, num_steps]
            - metrics: Optional list of StepQualityMetrics per step
        """
        batch_size, num_steps, _ = step_embeddings.shape

        # Project both modalities into the shared hidden space.
        vision_feat = self.vision_proj(vision_features)
        text_feat = self.text_proj(step_embeddings)

        # Broadcast the image representation across all steps.
        vision_expanded = vision_feat.unsqueeze(1).expand(-1, num_steps, -1)

        # Fuse modalities by addition.
        combined = vision_expanded + text_feat
        combined = self.dropout(combined)

        # Causal mask: each step attends only to itself and earlier steps.
        causal_mask = torch.triu(
            torch.ones(num_steps, num_steps, device=combined.device),
            diagonal=1,
        ).bool()

        attended, _ = self.context_attention(
            combined, combined, combined,
            attn_mask=causal_mask,
            key_padding_mask=~step_mask if step_mask is not None else None,
        )

        # Refine the attended step representations.
        encoded = self.step_encoder(attended)

        # Scalar reward per step.
        rewards = self.reward_head(encoded).squeeze(-1)

        if step_mask is not None:
            rewards = rewards.masked_fill(~step_mask, 0.0)

        # Optionally compute detailed per-step quality metrics.
        metrics = None
        if return_metrics:
            metrics = []
            for b in range(batch_size):
                step_metrics = []
                for s in range(num_steps):
                    if step_mask is None or step_mask[b, s]:
                        feat = encoded[b, s:s + 1]
                        quality = StepQualityMetrics(
                            coherence=self.coherence_head(feat).item(),
                            relevance=self.relevance_head(feat).item(),
                            correctness=self.correctness_head(feat).item(),
                            informativeness=self.informativeness_head(feat).item(),
                            confidence=self.confidence_head(feat).item(),
                        )
                        step_metrics.append(quality)
                metrics.append(step_metrics)

        return rewards, metrics

    def compute_step_reward(
        self,
        vision_features: torch.Tensor,
        current_step: torch.Tensor,
        previous_steps: Optional[torch.Tensor] = None,
        return_metrics: bool = False,
    ) -> Tuple[float, Optional[StepQualityMetrics]]:
        """
        Compute reward for a single reasoning step.

        Args:
            vision_features: Visual features [vision_dim]
            current_step: Current step embedding [text_dim]
            previous_steps: Previous step embeddings [num_prev_steps, text_dim]
            return_metrics: If True, return detailed metrics

        Returns:
            - reward: Scalar reward in [-1, 1]
            - metrics: Optional StepQualityMetrics
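
        Example (single step, no history; random vectors for illustration):
            >>> prm = ProcessRewardModel()
            >>> reward, _ = prm.compute_step_reward(torch.randn(768), torch.randn(768))
            >>> -1.0 <= reward <= 1.0
            True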
| | """ |
| | |
| | vision_features = vision_features.unsqueeze(0) |
| | |
| | if previous_steps is not None: |
| | |
| | steps = torch.cat([previous_steps, current_step.unsqueeze(0)], dim=0) |
| | else: |
| | steps = current_step.unsqueeze(0) |
| | |
| | steps = steps.unsqueeze(0) |
| | |
| | |
| | with torch.no_grad(): |
| | rewards, metrics = self.forward( |
| | vision_features, |
| | steps, |
| | return_metrics=return_metrics, |
| | ) |
| | |
| | |
| | reward = rewards[0, -1].item() |
| | step_metrics = metrics[0][-1] if metrics else None |
| | |
| | return reward, step_metrics |

    @torch.no_grad()
    def evaluate_chain(
        self,
        vision_features: torch.Tensor,
        step_embeddings: List[torch.Tensor],
    ) -> Tuple[List[float], float]:
        """
        Evaluate a complete reasoning chain.

        Args:
            vision_features: Visual features [vision_dim]
            step_embeddings: List of step embeddings

        Returns:
            - step_rewards: Reward for each step
            - total_reward: Sum of all step rewards
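
        Example (three random steps):
            >>> prm = ProcessRewardModel()
            >>> chain = [torch.randn(768) for _ in range(3)]
            >>> step_rewards, total = prm.evaluate_chain(torch.randn(768), chain)
            >>> len(step_rewards)
            3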
| | """ |
| | if not step_embeddings: |
| | return [], 0.0 |
| | |
| | |
| | steps = torch.stack(step_embeddings).unsqueeze(0) |
| | vision_features = vision_features.unsqueeze(0) |
| | |
| | |
| | rewards, _ = self.forward(vision_features, steps) |
| | step_rewards = rewards[0].tolist() |
| | total_reward = sum(step_rewards) |
| | |
| | return step_rewards, total_reward |


class PRMTrainer:
    """Trainer for Process Reward Model using step-level supervision."""

    def __init__(
        self,
        model: ProcessRewardModel,
        learning_rate: float = 1e-4,
        weight_decay: float = 0.01,
        warmup_steps: int = 1000,
    ):
        """
        Initialize PRM trainer.

        Args:
            model: Process Reward Model
            learning_rate: Learning rate
            weight_decay: Weight decay for regularization
            warmup_steps: Number of warmup steps
        """
        self.model = model
        self.learning_rate = learning_rate
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
        )
        self.warmup_steps = warmup_steps
        self.current_step = 0

    def train_step(
        self,
        vision_features: torch.Tensor,
        step_embeddings: torch.Tensor,
        target_rewards: torch.Tensor,
        step_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, float]:
        """
        Single training step.

        Args:
            vision_features: Visual features [batch, vision_dim]
            step_embeddings: Step embeddings [batch, num_steps, text_dim]
            target_rewards: Ground truth rewards [batch, num_steps]
            step_mask: Valid step mask [batch, num_steps]

        Returns:
            Dictionary of losses
        """
        self.model.train()
        self.optimizer.zero_grad()

        predicted_rewards, _ = self.model(
            vision_features,
            step_embeddings,
            step_mask,
        )

        # MSE over valid steps only.
        if step_mask is not None:
            loss = F.mse_loss(
                predicted_rewards.masked_select(step_mask),
                target_rewards.masked_select(step_mask),
            )
        else:
            loss = F.mse_loss(predicted_rewards, target_rewards)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

        # Linear learning-rate warmup, applied before the optimizer step.
        if self.current_step < self.warmup_steps:
            lr_scale = min(1.0, (self.current_step + 1) / self.warmup_steps)
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.learning_rate * lr_scale

        self.optimizer.step()
        self.current_step += 1

        return {
            'loss': loss.item(),
            'avg_predicted_reward': predicted_rewards.mean().item(),
            'avg_target_reward': target_rewards.mean().item(),
        }

    @torch.no_grad()
    def evaluate(
        self,
        vision_features: torch.Tensor,
        step_embeddings: torch.Tensor,
        target_rewards: torch.Tensor,
        step_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, float]:
        """
        Evaluation step.

        Returns:
            Dictionary of evaluation metrics
        """
        self.model.eval()

        predicted_rewards, _ = self.model(
            vision_features,
            step_embeddings,
            step_mask,
        )

        # Keep only valid (unmasked) steps.
        if step_mask is not None:
            valid_pred = predicted_rewards.masked_select(step_mask)
            valid_target = target_rewards.masked_select(step_mask)
        else:
            valid_pred = predicted_rewards.flatten()
            valid_target = target_rewards.flatten()

        mse = F.mse_loss(valid_pred, valid_target).item()
        mae = F.l1_loss(valid_pred, valid_target).item()

        # Pearson correlation between predicted and target rewards.
        pred_mean = valid_pred.mean()
        target_mean = valid_target.mean()
        covariance = ((valid_pred - pred_mean) * (valid_target - target_mean)).mean()
        pred_std = valid_pred.std()
        target_std = valid_target.std()
        correlation = covariance / (pred_std * target_std + 1e-8)

        return {
            'mse': mse,
            'mae': mae,
            'correlation': correlation.item(),
            'avg_predicted_reward': valid_pred.mean().item(),
            'avg_target_reward': valid_target.mean().item(),
        }
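

if __name__ == "__main__":
    # Minimal smoke test with random tensors. Illustrative only: a real setup
    # would feed encoder outputs and human step-level reward labels; shapes
    # and hyperparameters below are assumptions, not part of any fixed API.
    logging.basicConfig(level=logging.INFO)
    torch.manual_seed(0)

    prm = ProcessRewardModel()
    trainer = PRMTrainer(prm, warmup_steps=10)

    vision = torch.randn(4, 768)             # [batch, vision_dim]
    steps = torch.randn(4, 6, 768)           # [batch, num_steps, text_dim]
    targets = torch.rand(4, 6) * 2.0 - 1.0   # target rewards in [-1, 1]
    mask = torch.ones(4, 6, dtype=torch.bool)

    logger.info("train: %s", trainer.train_step(vision, steps, targets, mask))
    logger.info("eval: %s", trainer.evaluate(vision, steps, targets, mask))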