# dei-model/src/reasoning/rl_trainer.py
"""
Reinforcement Learning trainer for vision-language models with step-level rewards.
Implements PPO (Proximal Policy Optimization) with fine-grained process rewards
from the PRM for each reasoning step.
"""
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from dataclasses import dataclass
import logging
from .prm import ProcessRewardModel
from .step_data import StepDataset
logger = logging.getLogger(__name__)
@dataclass
class RLConfig:
"""Configuration for RL training."""
# PPO hyperparameters
clip_epsilon: float = 0.2
value_loss_coef: float = 0.5
entropy_coef: float = 0.01
max_grad_norm: float = 1.0
gamma: float = 0.99 # Discount factor
gae_lambda: float = 0.95 # GAE parameter
# Training settings
num_epochs: int = 4
batch_size: int = 8
learning_rate: float = 1e-5
warmup_steps: int = 500
# Step-level rewards
use_prm_rewards: bool = True
normalize_rewards: bool = True
reward_scale: float = 1.0
# Exploration
temperature: float = 1.0
top_k: int = 50
top_p: float = 0.95
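# A minimal construction sketch (values here are illustrative, not tuned):
#
#   config = RLConfig(batch_size=4, learning_rate=5e-6, entropy_coef=0.02)
#   # Unspecified fields keep the dataclass defaults above.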
class RLReasoningTrainer:
"""
Reinforcement Learning trainer for step-level reasoning.
Uses PPO to optimize the vision-language model with fine-grained rewards
from the Process Reward Model (PRM) at each reasoning step.
"""
def __init__(
self,
policy_model: nn.Module,
prm_model: ProcessRewardModel,
config: RLConfig,
device: str = "cuda",
):
"""
Initialize RL trainer.
Args:
policy_model: Vision-language model to train
prm_model: Process Reward Model for step evaluation
config: RL training configuration
device: Device for training
"""
self.policy = policy_model
self.prm = prm_model
self.config = config
self.device = device
# Move models to device (skip if already quantized with device_map)
if not (hasattr(self.policy, 'hf_device_map') or
getattr(self.policy, 'is_quantized', False)):
self.policy.to(device)
if not (hasattr(self.prm, 'hf_device_map') or
getattr(self.prm, 'is_quantized', False)):
self.prm.to(device)
# Freeze PRM (only train policy)
for param in self.prm.parameters():
param.requires_grad = False
self.prm.eval()
# Optimizer
self.optimizer = torch.optim.AdamW(
self.policy.parameters(),
lr=config.learning_rate,
betas=(0.9, 0.999),
eps=1e-8,
)
        # Value head for PPO (estimates the state value from the policy's
        # last hidden state; the 768 input width assumes the policy's hidden
        # size and should match `hidden_size` for other backbones)
        self.value_head = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 1),
        ).to(device)
self.value_optimizer = torch.optim.AdamW(
self.value_head.parameters(),
lr=config.learning_rate * 3,
)
self.global_step = 0
def generate_reasoning_chain(
self,
image_features: torch.Tensor,
prompt_ids: torch.Tensor,
prompt_mask: torch.Tensor,
max_steps: int = 10,
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[float]]:
"""
Generate a reasoning chain using the current policy.
Args:
image_features: Visual features [vision_dim]
prompt_ids: Tokenized prompt [seq_len]
prompt_mask: Prompt attention mask [seq_len]
max_steps: Maximum reasoning steps
Returns:
- step_embeddings: List of step embeddings
- step_logprobs: List of log probabilities for each step
- step_entropies: List of entropies for each step
"""
self.policy.eval()
step_embeddings = []
step_logprobs = []
step_entropies = []
        # Initialize with the prompt. Note that `image_features` is not fed
        # to the forward call below; this assumes the visual features are
        # already fused into the policy upstream (e.g. via cross-attention).
        current_ids = prompt_ids.unsqueeze(0).to(self.device)
        current_mask = prompt_mask.unsqueeze(0).to(self.device)
for step_idx in range(max_steps):
with torch.no_grad():
# Forward pass through policy
outputs = self.policy(
input_ids=current_ids,
attention_mask=current_mask,
output_hidden_states=True,
)
logits = outputs.logits[:, -1, :] # Next token logits
hidden_state = outputs.hidden_states[-1][:, -1, :] # Last hidden state
# Apply temperature and sampling
logits = logits / self.config.temperature
probs = F.softmax(logits, dim=-1)
# Top-k and top-p filtering
if self.config.top_k > 0:
top_k_probs, top_k_indices = torch.topk(probs, self.config.top_k, dim=-1)
probs = torch.zeros_like(probs).scatter_(1, top_k_indices, top_k_probs)
probs = probs / probs.sum(dim=-1, keepdim=True)
if self.config.top_p < 1.0:
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
# Remove tokens with cumulative probability above threshold
sorted_indices_to_remove = cumulative_probs > self.config.top_p
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
probs = probs.masked_fill(indices_to_remove, 0.0)
probs = probs / probs.sum(dim=-1, keepdim=True)
# Sample next token
next_token = torch.multinomial(probs, num_samples=1)
# Calculate log probability and entropy
log_prob = torch.log(probs.gather(1, next_token) + 1e-10)
entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=-1)
# Store step information
step_embeddings.append(hidden_state.squeeze(0))
step_logprobs.append(log_prob.squeeze())
step_entropies.append(entropy.item())
            # Check for stop token or step-completion marker (198 is the
            # newline id in the GPT-2 BPE vocabulary; adjust for other
            # tokenizers)
            if next_token.item() in [self.policy.config.eos_token_id, 198]:
                break
# Update sequence
current_ids = torch.cat([current_ids, next_token], dim=1)
current_mask = torch.cat([
current_mask,
torch.ones((1, 1), device=self.device)
], dim=1)
return step_embeddings, step_logprobs, step_entropies
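    # Per-sample rollout sketch (shapes per the docstring above;
    # `visual_feats` and the tokenized prompt are hypothetical inputs):
    #
    #   embs, logps, ents = trainer.generate_reasoning_chain(
    #       visual_feats,                       # [vision_dim]
    #       prompt_ids=tokenized.input_ids[0],  # [seq_len]
    #       prompt_mask=tokenized.attention_mask[0],
    #       max_steps=10,
    #   )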
def compute_advantages(
self,
rewards: torch.Tensor,
values: torch.Tensor,
dones: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute Generalized Advantage Estimation (GAE).
Args:
rewards: Step rewards [num_steps]
values: Value estimates [num_steps]
dones: Done flags [num_steps]
Returns:
- advantages: Advantage estimates [num_steps]
- returns: Discounted returns [num_steps]
"""
num_steps = len(rewards)
advantages = torch.zeros_like(rewards)
returns = torch.zeros_like(rewards)
        gae = 0.0
        for t in reversed(range(num_steps)):
            next_non_terminal = 1.0 - dones[t]
            # Bootstrap from the next value estimate; zero past the final step
            next_value = values[t + 1] if t < num_steps - 1 else 0.0
delta = rewards[t] + self.config.gamma * next_value * next_non_terminal - values[t]
gae = delta + self.config.gamma * self.config.gae_lambda * next_non_terminal * gae
advantages[t] = gae
returns[t] = gae + values[t]
return advantages, returns
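    # For reference, the recurrence implemented above (gamma and lambda come
    # from RLConfig):
    #
    #   delta_t = r_t + gamma * V_{t+1} * (1 - done_t) - V_t
    #   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
    #   R_t     = A_t + V_t   (regression target for the value head)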
def ppo_update(
self,
step_embeddings: List[torch.Tensor],
old_logprobs: torch.Tensor,
rewards: torch.Tensor,
advantages: torch.Tensor,
returns: torch.Tensor,
) -> Dict[str, float]:
"""
Perform PPO update.
Args:
step_embeddings: Hidden states for each step
old_logprobs: Old policy log probabilities
rewards: Step rewards
advantages: Advantage estimates
returns: Discounted returns
Returns:
Dictionary of training metrics
"""
self.policy.train()
self.value_head.train()
# Stack embeddings
embeddings = torch.stack(step_embeddings) # [num_steps, hidden_dim]
# Normalize advantages
if self.config.normalize_rewards:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
total_policy_loss = 0
total_value_loss = 0
total_entropy_loss = 0
        for epoch in range(self.config.num_epochs):
            # Value estimates for each step embedding
            values = self.value_head(embeddings).squeeze(-1)
            # Value loss: regress toward the discounted returns
            value_loss = F.mse_loss(values, returns)
            # Placeholder importance ratio. A full PPO update would re-run the
            # policy over the sampled tokens with gradients enabled and use
            # exp(new_logprobs - old_logprobs); with old == new the ratio is
            # identically 1, so this simplification carries no policy gradient
            # (see the sketch after this method).
            ratio = torch.exp(old_logprobs - old_logprobs)
# PPO clipped objective
policy_loss_1 = -advantages * ratio
policy_loss_2 = -advantages * torch.clamp(
ratio,
1 - self.config.clip_epsilon,
1 + self.config.clip_epsilon
)
policy_loss = torch.max(policy_loss_1, policy_loss_2).mean()
            # Entropy bonus for exploration: the mean logprob is a one-sample
            # estimate of -entropy, so adding it to the loss (scaled by
            # entropy_coef) pushes logprobs down and entropy up
            entropy_loss = self.config.entropy_coef * old_logprobs.mean()
# Total loss
loss = (
policy_loss +
self.config.value_loss_coef * value_loss +
entropy_loss
)
# Optimization step
self.optimizer.zero_grad()
self.value_optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(
self.policy.parameters(),
self.config.max_grad_norm
)
torch.nn.utils.clip_grad_norm_(
self.value_head.parameters(),
self.config.max_grad_norm
)
self.optimizer.step()
self.value_optimizer.step()
total_policy_loss += policy_loss.item()
total_value_loss += value_loss.item()
total_entropy_loss += entropy_loss.item()
self.global_step += 1
return {
'policy_loss': total_policy_loss / self.config.num_epochs,
'value_loss': total_value_loss / self.config.num_epochs,
'entropy_loss': total_entropy_loss / self.config.num_epochs,
'avg_reward': rewards.mean().item(),
'avg_advantage': advantages.mean().item(),
'avg_return': returns.mean().item(),
}
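    # Sketch of the full ratio computation this method simplifies away,
    # assuming the rollout also retained the sampled token ids (`sampled_ids`
    # and `sampled_mask` are hypothetical here):
    #
    #   out = self.policy(input_ids=sampled_ids, attention_mask=sampled_mask)
    #   logp = F.log_softmax(out.logits[:, :-1, :], dim=-1)
    #   new_logprobs = logp.gather(
    #       -1, sampled_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
    #   ratio = torch.exp(new_logprobs - old_logprobs)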
def train_step(
self,
batch: Dict[str, torch.Tensor],
) -> Dict[str, float]:
"""
Single RL training step.
Args:
batch: Batch from StepDataset
Returns:
Training metrics
"""
# Extract batch data
image_features = batch['visual_features'].to(self.device) # [batch, vision_dim]
prompt_ids = batch['prompt_ids'].to(self.device)
prompt_mask = batch['prompt_mask'].to(self.device)
batch_size = image_features.size(0)
all_metrics = []
for b in range(batch_size):
# Generate reasoning chain with current policy
step_embeddings, step_logprobs, step_entropies = self.generate_reasoning_chain(
image_features[b],
prompt_ids[b],
prompt_mask[b],
max_steps=10,
)
if not step_embeddings:
continue
            # Get PRM rewards for each step, conditioning on all prior steps
            step_rewards = []
            for i, step_emb in enumerate(step_embeddings):
                reward, _ = self.prm.compute_step_reward(
                    image_features[b],
                    step_emb,
                    previous_steps=torch.stack(step_embeddings[:i]) if i > 0 else None,
                )
                step_rewards.append(float(reward) * self.config.reward_scale)
rewards = torch.tensor(step_rewards, device=self.device)
logprobs = torch.stack(step_logprobs)
# Compute value estimates
embeddings = torch.stack(step_embeddings)
with torch.no_grad():
values = self.value_head(embeddings).squeeze(-1)
# Done flags (only last step is done)
dones = torch.zeros(len(step_embeddings), device=self.device)
dones[-1] = 1.0
# Compute advantages and returns
advantages, returns = self.compute_advantages(rewards, values, dones)
# PPO update
metrics = self.ppo_update(
step_embeddings,
logprobs,
rewards,
advantages,
returns,
)
all_metrics.append(metrics)
# Average metrics across batch
if not all_metrics:
return {}
avg_metrics = {
key: sum(m[key] for m in all_metrics) / len(all_metrics)
for key in all_metrics[0].keys()
}
return avg_metrics
def train(
self,
train_dataset: StepDataset,
num_iterations: int = 1000,
log_interval: int = 10,
) -> None:
"""
Train the policy with RL.
        Args:
            train_dataset: Dataset with reasoning chains
            num_iterations: Number of passes over the dataset
            log_interval: Logging frequency, in iterations
"""
dataloader = DataLoader(
train_dataset,
batch_size=self.config.batch_size,
shuffle=True,
)
logger.info(f"Starting RL training for {num_iterations} iterations")
for iteration in range(num_iterations):
for batch in dataloader:
metrics = self.train_step(batch)
if metrics and iteration % log_interval == 0:
logger.info(
f"Iteration {iteration}: "
f"Policy Loss: {metrics['policy_loss']:.4f}, "
f"Value Loss: {metrics['value_loss']:.4f}, "
f"Avg Reward: {metrics['avg_reward']:.4f}"
)
logger.info("RL training completed")
def save_checkpoint(self, path: str) -> None:
"""Save training checkpoint."""
checkpoint = {
'policy_state_dict': self.policy.state_dict(),
'value_head_state_dict': self.value_head.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'value_optimizer_state_dict': self.value_optimizer.state_dict(),
'global_step': self.global_step,
'config': self.config,
}
torch.save(checkpoint, path)
logger.info(f"Checkpoint saved to {path}")
def load_checkpoint(self, path: str) -> None:
"""Load training checkpoint."""
checkpoint = torch.load(path, map_location=self.device)
self.policy.load_state_dict(checkpoint['policy_state_dict'])
self.value_head.load_state_dict(checkpoint['value_head_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.value_optimizer.load_state_dict(checkpoint['value_optimizer_state_dict'])
self.global_step = checkpoint['global_step']
logger.info(f"Checkpoint loaded from {path}")
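
# End-to-end usage sketch. The policy loader and PRM constructor below are
# hypothetical placeholders for whatever the surrounding codebase provides:
#
#   policy = load_vision_language_model(...)          # hypothetical loader
#   prm = ProcessRewardModel.from_pretrained("...")   # hypothetical ctor
#   trainer = RLReasoningTrainer(policy, prm, RLConfig(), device="cuda")
#   trainer.train(StepDataset(...), num_iterations=1000, log_interval=10)
#   trainer.save_checkpoint("checkpoints/rl_reasoning.pt")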