"""
Process Reward Model (PRM) for evaluating step-level reasoning quality.
The PRM assigns rewards to intermediate reasoning steps, enabling fine-grained
reinforcement learning and quality assessment.
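
Example (a minimal sketch with synthetic tensors; batch size, step count,
and feature dimensions below are illustrative, not required values):

    import torch

    prm = ProcessRewardModel(vision_dim=768, text_dim=768)
    vision = torch.randn(2, 768)     # [batch, vision_dim]
    steps = torch.randn(2, 5, 768)   # [batch, num_steps, text_dim]
    rewards, _ = prm(vision, steps)  # [batch, num_steps], each in [-1, 1]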
"""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
logger = logging.getLogger(__name__)
@dataclass
class StepQualityMetrics:
"""Metrics for assessing reasoning step quality."""
    coherence: float        # How well the step follows from previous steps [0, 1]
    relevance: float        # Relevance to the prompt and image [0, 1]
    correctness: float      # Logical correctness [0, 1]
    informativeness: float  # How much new information the step adds [0, 1]
    confidence: float       # The model's confidence in this step [0, 1]
def to_reward(self, weights: Optional[Dict[str, float]] = None) -> float:
"""
Convert metrics to single reward score.
Args:
weights: Optional custom weights for each metric
Returns:
Reward in [-1, 1] range
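
        Example:
            With the default weights (which sum to 1.0), if every metric
            equals 0.8 the weighted score is 0.8 and the returned reward
            is 0.8 * 2 - 1 = 0.6.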
"""
if weights is None:
weights = {
'coherence': 0.25,
'relevance': 0.25,
'correctness': 0.3,
'informativeness': 0.15,
'confidence': 0.05,
}
        # Weighted average; normalizing by the weight sum keeps the score in
        # [0, 1] even when custom weights do not sum to 1
        total_weight = sum(weights.values())
        score = sum(
            getattr(self, key) * weight
            for key, weight in weights.items()
        ) / total_weight
        # Map [0, 1] onto [-1, 1]
        return score * 2.0 - 1.0
class ProcessRewardModel(nn.Module):
"""
Process Reward Model for step-level reasoning evaluation.
Architecture:
- Vision encoder: Extract visual features
- Text encoder: Encode reasoning steps
- Context aggregator: Combine previous steps
- Reward head: Predict step quality metrics
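
    Shape sketch (with default hidden_dim=512; sizes are illustrative):
        vision [B, 768] --proj--> [B, 512], broadcast over steps
        steps  [B, S, 768] --proj--> [B, S, 512]
        fused  [B, S, 512] --causal attention + encoder--> rewards [B, S]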
"""
def __init__(
self,
vision_dim: int = 768,
text_dim: int = 768,
hidden_dim: int = 512,
num_heads: int = 8,
dropout: float = 0.1,
max_steps: int = 10,
):
"""
Initialize PRM.
Args:
vision_dim: Dimension of vision features
text_dim: Dimension of text embeddings
hidden_dim: Hidden dimension for processing
num_heads: Number of attention heads
dropout: Dropout rate
max_steps: Maximum reasoning steps to handle
"""
super().__init__()
self.vision_dim = vision_dim
self.text_dim = text_dim
self.hidden_dim = hidden_dim
self.max_steps = max_steps
# Vision and text projections
self.vision_proj = nn.Linear(vision_dim, hidden_dim)
self.text_proj = nn.Linear(text_dim, hidden_dim)
# Context aggregator with self-attention over previous steps
self.context_attention = nn.MultiheadAttention(
embed_dim=hidden_dim,
num_heads=num_heads,
dropout=dropout,
batch_first=True,
)
# Step encoder
self.step_encoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=hidden_dim,
nhead=num_heads,
dim_feedforward=hidden_dim * 4,
dropout=dropout,
batch_first=True,
),
num_layers=2,
)
        # Reward heads for different metrics: two-layer MLPs sharing one
        # layout, built by a small factory instead of repeating the same
        # Sequential five times
        def make_head(activation: nn.Module) -> nn.Sequential:
            return nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim // 2, 1),
                activation,
            )

        # Per-metric heads, each outputting in [0, 1]
        self.coherence_head = make_head(nn.Sigmoid())
        self.relevance_head = make_head(nn.Sigmoid())
        self.correctness_head = make_head(nn.Sigmoid())
        self.informativeness_head = make_head(nn.Sigmoid())
        # Confidence estimation
        self.confidence_head = make_head(nn.Sigmoid())
        # Overall reward head (learned combination), output in [-1, 1]
        self.reward_head = make_head(nn.Tanh())
self.dropout = nn.Dropout(dropout)
def forward(
self,
vision_features: torch.Tensor,
step_embeddings: torch.Tensor,
step_mask: Optional[torch.Tensor] = None,
return_metrics: bool = False,
) -> Tuple[torch.Tensor, Optional[List[StepQualityMetrics]]]:
"""
Forward pass to compute rewards for reasoning steps.
Args:
vision_features: Visual features [batch, vision_dim]
step_embeddings: Text embeddings for each step [batch, num_steps, text_dim]
step_mask: Mask for valid steps [batch, num_steps]
return_metrics: If True, return detailed metrics
Returns:
- rewards: Reward for each step [batch, num_steps]
- metrics: Optional list of StepQualityMetrics per step
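
        Example (illustrative shapes):
            vision_features [4, 768], step_embeddings [4, 6, 768], and a
            boolean step_mask [4, 6] yield rewards [4, 6], with rewards at
            padded positions set to 0.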
"""
batch_size, num_steps, _ = step_embeddings.shape
# Project vision and text features
vision_feat = self.vision_proj(vision_features) # [batch, hidden_dim]
text_feat = self.text_proj(step_embeddings) # [batch, num_steps, hidden_dim]
# Expand vision features to match step dimension
vision_expanded = vision_feat.unsqueeze(1).expand(-1, num_steps, -1)
# Combine vision and text
combined = vision_expanded + text_feat # [batch, num_steps, hidden_dim]
combined = self.dropout(combined)
        # Context attention: a causal mask ensures each step attends only to
        # itself and earlier steps (autoregressive reasoning), independent of
        # whether a padding mask was provided
        causal_mask = torch.triu(
            torch.ones(num_steps, num_steps, device=combined.device, dtype=torch.bool),
            diagonal=1,
        )
        attended, _ = self.context_attention(
            combined, combined, combined,
            attn_mask=causal_mask,
            key_padding_mask=~step_mask.bool() if step_mask is not None else None,
        )
        # Encode steps with the same causal/padding masks so the encoder
        # cannot leak information from future or padded steps
        encoded = self.step_encoder(
            attended,
            mask=causal_mask,
            src_key_padding_mask=~step_mask.bool() if step_mask is not None else None,
        )  # [batch, num_steps, hidden_dim]
# Compute rewards
rewards = self.reward_head(encoded).squeeze(-1) # [batch, num_steps]
if step_mask is not None:
rewards = rewards.masked_fill(~step_mask, 0.0)
# Optionally compute detailed metrics
metrics = None
if return_metrics:
metrics = []
for b in range(batch_size):
step_metrics = []
for s in range(num_steps):
if step_mask is None or step_mask[b, s]:
feat = encoded[b, s:s+1]
quality = StepQualityMetrics(
coherence=self.coherence_head(feat).item(),
relevance=self.relevance_head(feat).item(),
correctness=self.correctness_head(feat).item(),
informativeness=self.informativeness_head(feat).item(),
confidence=self.confidence_head(feat).item(),
)
step_metrics.append(quality)
metrics.append(step_metrics)
return rewards, metrics
def compute_step_reward(
self,
vision_features: torch.Tensor,
current_step: torch.Tensor,
previous_steps: Optional[torch.Tensor] = None,
return_metrics: bool = False,
) -> Tuple[float, Optional[StepQualityMetrics]]:
"""
Compute reward for a single reasoning step.
Args:
vision_features: Visual features [vision_dim]
current_step: Current step embedding [text_dim]
previous_steps: Previous step embeddings [num_prev_steps, text_dim]
return_metrics: If True, return detailed metrics
Returns:
- reward: Scalar reward in [-1, 1]
- metrics: Optional StepQualityMetrics
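
        Example (a sketch; real embeddings would come from your own
        vision/text encoders, and `prm` is an instantiated model):

            vision = torch.randn(768)
            prev = torch.randn(3, 768)   # three earlier steps
            cur = torch.randn(768)
            reward, m = prm.compute_step_reward(
                vision, cur, prev, return_metrics=True
            )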
"""
# Prepare batch inputs
vision_features = vision_features.unsqueeze(0) # [1, vision_dim]
if previous_steps is not None:
# Concatenate previous and current
steps = torch.cat([previous_steps, current_step.unsqueeze(0)], dim=0)
else:
steps = current_step.unsqueeze(0)
steps = steps.unsqueeze(0) # [1, num_steps, text_dim]
# Forward pass
with torch.no_grad():
rewards, metrics = self.forward(
vision_features,
steps,
return_metrics=return_metrics,
)
# Extract last step reward
reward = rewards[0, -1].item()
step_metrics = metrics[0][-1] if metrics else None
return reward, step_metrics
@torch.no_grad()
def evaluate_chain(
self,
vision_features: torch.Tensor,
step_embeddings: List[torch.Tensor],
) -> Tuple[List[float], float]:
"""
Evaluate a complete reasoning chain.
Args:
vision_features: Visual features [vision_dim]
step_embeddings: List of step embeddings
Returns:
- step_rewards: Reward for each step
- total_reward: Sum of all step rewards
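
        Example (illustrative; `prm` is an instantiated model):

            chain = [torch.randn(768) for _ in range(4)]
            step_rewards, total = prm.evaluate_chain(torch.randn(768), chain)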
"""
if not step_embeddings:
return [], 0.0
# Stack steps
steps = torch.stack(step_embeddings).unsqueeze(0) # [1, num_steps, text_dim]
vision_features = vision_features.unsqueeze(0) # [1, vision_dim]
# Compute rewards
rewards, _ = self.forward(vision_features, steps)
step_rewards = rewards[0].tolist()
total_reward = sum(step_rewards)
return step_rewards, total_reward
class PRMTrainer:
"""Trainer for Process Reward Model using step-level supervision."""
def __init__(
self,
model: ProcessRewardModel,
learning_rate: float = 1e-4,
weight_decay: float = 0.01,
warmup_steps: int = 1000,
):
"""
Initialize PRM trainer.
Args:
model: Process Reward Model
learning_rate: Learning rate
weight_decay: Weight decay for regularization
warmup_steps: Number of warmup steps
"""
self.model = model
self.optimizer = torch.optim.AdamW(
model.parameters(),
lr=learning_rate,
weight_decay=weight_decay,
)
        self.learning_rate = learning_rate
        self.warmup_steps = warmup_steps
        self.current_step = 0
def train_step(
self,
vision_features: torch.Tensor,
step_embeddings: torch.Tensor,
target_rewards: torch.Tensor,
step_mask: Optional[torch.Tensor] = None,
) -> Dict[str, float]:
"""
Single training step.
Args:
vision_features: Visual features [batch, vision_dim]
step_embeddings: Step embeddings [batch, num_steps, text_dim]
target_rewards: Ground truth rewards [batch, num_steps]
step_mask: Valid step mask [batch, num_steps]
Returns:
Dictionary of losses
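
        Example (one step on synthetic data; shapes are illustrative):

            trainer = PRMTrainer(prm)
            stats = trainer.train_step(
                torch.randn(8, 768),       # vision_features
                torch.randn(8, 5, 768),    # step_embeddings
                torch.rand(8, 5) * 2 - 1,  # target rewards in [-1, 1]
            )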
"""
self.model.train()
self.optimizer.zero_grad()
# Forward pass
predicted_rewards, _ = self.model(
vision_features,
step_embeddings,
step_mask,
)
# Compute loss
if step_mask is not None:
# Masked MSE loss
loss = F.mse_loss(
predicted_rewards.masked_select(step_mask),
target_rewards.masked_select(step_mask),
)
else:
loss = F.mse_loss(predicted_rewards, target_rewards)
        # Linear learning-rate warmup, applied before the optimizer step
        if self.current_step < self.warmup_steps:
            lr_scale = min(1.0, (self.current_step + 1) / self.warmup_steps)
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.learning_rate * lr_scale
        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        self.current_step += 1
return {
'loss': loss.item(),
'avg_predicted_reward': predicted_rewards.mean().item(),
'avg_target_reward': target_rewards.mean().item(),
}
@torch.no_grad()
def evaluate(
self,
vision_features: torch.Tensor,
step_embeddings: torch.Tensor,
target_rewards: torch.Tensor,
step_mask: Optional[torch.Tensor] = None,
) -> Dict[str, float]:
"""
Evaluation step.
Returns:
Dictionary of evaluation metrics
"""
self.model.eval()
# Forward pass
predicted_rewards, _ = self.model(
vision_features,
step_embeddings,
step_mask,
)
# Compute metrics
if step_mask is not None:
valid_pred = predicted_rewards.masked_select(step_mask)
valid_target = target_rewards.masked_select(step_mask)
else:
valid_pred = predicted_rewards.flatten()
valid_target = target_rewards.flatten()
mse = F.mse_loss(valid_pred, valid_target).item()
mae = F.l1_loss(valid_pred, valid_target).item()
        # Pearson correlation (population statistics; eps for stability)
        pred_mean = valid_pred.mean()
        target_mean = valid_target.mean()
        covariance = ((valid_pred - pred_mean) * (valid_target - target_mean)).mean()
        pred_std = valid_pred.std(unbiased=False)
        target_std = valid_target.std(unbiased=False)
        correlation = covariance / (pred_std * target_std + 1e-8)
return {
'mse': mse,
'mae': mae,
'correlation': correlation.item(),
'avg_predicted_reward': valid_pred.mean().item(),
'avg_target_reward': valid_target.mean().item(),
}
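

if __name__ == "__main__":
    # Minimal smoke test on synthetic tensors (illustrative sizes only);
    # a quick way to confirm shapes and the training step run end to end.
    torch.manual_seed(0)
    model = ProcessRewardModel()
    vision = torch.randn(2, 768)
    steps = torch.randn(2, 4, 768)
    mask = torch.tensor([[True, True, True, False],
                         [True, True, True, True]])
    rewards, metrics = model(vision, steps, step_mask=mask, return_metrics=True)
    print("rewards:", tuple(rewards.shape), rewards)

    trainer = PRMTrainer(model)
    stats = trainer.train_step(vision, steps, torch.zeros(2, 4), step_mask=mask)
    print("train step:", stats)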