| | """ |
| | Reinforcement Learning trainer for vision-language models with step-level rewards. |
| | |
| | Implements PPO (Proximal Policy Optimization) with fine-grained process rewards |
| | from the PRM for each reasoning step. |
| | """ |
| |
|
| | from typing import Dict, List, Optional, Tuple, Any |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.utils.data import DataLoader |
| | from dataclasses import dataclass |
| | import logging |
| |
|
| | from .prm import ProcessRewardModel |
| | from .step_data import ReasoningChain, StepDataset |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
@dataclass
class RLConfig:
    """Hyper-parameters for PPO-based RL training.

    Grouped by concern: PPO loss shaping, optimization schedule,
    reward handling, and sampling behavior during rollout generation.
    """

    # --- PPO loss ---
    clip_epsilon: float = 0.2        # clip range for the importance ratio
    value_loss_coef: float = 0.5     # weight of the critic (value) loss
    entropy_coef: float = 0.01       # weight of the entropy bonus
    max_grad_norm: float = 1.0       # gradient clipping threshold
    gamma: float = 0.99              # discount factor
    gae_lambda: float = 0.95         # GAE bias/variance trade-off

    # --- Optimization ---
    num_epochs: int = 4              # PPO epochs per rollout
    batch_size: int = 8
    learning_rate: float = 1e-5
    warmup_steps: int = 500

    # --- Rewards ---
    use_prm_rewards: bool = True     # score steps with the PRM
    normalize_rewards: bool = True   # standardize advantages per chain
    reward_scale: float = 1.0

    # --- Sampling ---
    temperature: float = 1.0
    top_k: int = 50                  # 0 disables top-k filtering
    top_p: float = 0.95              # 1.0 disables nucleus filtering
| |
|
| |
|
class RLReasoningTrainer:
    """
    Reinforcement Learning trainer for step-level reasoning.

    Uses PPO to optimize the vision-language model with fine-grained rewards
    from the Process Reward Model (PRM) at each reasoning step.

    Fixes over the previous revision:
      * the PPO importance ratio was ``exp(old - old) == 1`` and the sampled
        log-probs were collected under ``no_grad``, so the policy received no
        gradient at all; log-probs are now recomputed with gradients enabled;
      * the entropy "bonus" had an inverted sign (it penalized exploration);
      * advantage normalization produced NaNs for single-step chains;
      * the value-head input width is inferred from the policy config instead
        of being hard-coded to 768.
    """

    def __init__(
        self,
        policy_model: nn.Module,
        prm_model: ProcessRewardModel,
        config: RLConfig,
        device: str = "cuda",
    ):
        """
        Initialize RL trainer.

        Args:
            policy_model: Vision-language model to train
            prm_model: Process Reward Model for step evaluation
            config: RL training configuration
            device: Device for training
        """
        self.policy = policy_model
        self.prm = prm_model
        self.config = config
        self.device = device

        # Models managed by `accelerate` (hf_device_map) or already quantized
        # must not be moved manually — doing so breaks their placement.
        if not (hasattr(self.policy, 'hf_device_map') or
                getattr(self.policy, 'is_quantized', False)):
            self.policy.to(device)
        if not (hasattr(self.prm, 'hf_device_map') or
                getattr(self.prm, 'is_quantized', False)):
            self.prm.to(device)

        # The PRM is a fixed reward oracle: freeze it and keep it in eval mode.
        for param in self.prm.parameters():
            param.requires_grad = False
        self.prm.eval()

        self.optimizer = torch.optim.AdamW(
            self.policy.parameters(),
            lr=config.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-8,
        )

        # Infer the value-head input width from the policy when possible.
        # Fix: a hard-coded 768 crashed for any model with a different
        # hidden size; 768 remains the fallback for configs without the field.
        hidden_size = getattr(getattr(policy_model, "config", None),
                              "hidden_size", None) or 768
        self.value_head = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 1),
        ).to(device)

        self.value_optimizer = torch.optim.AdamW(
            self.value_head.parameters(),
            lr=config.learning_rate * 3,  # critic typically learns faster
        )

        self.global_step = 0

        # Rollout context cached by generate_reasoning_chain so that
        # ppo_update can recompute log-probs with gradients enabled.
        self._last_sequence_ids: Optional[torch.Tensor] = None
        self._last_attention_mask: Optional[torch.Tensor] = None
        self._last_prompt_len: int = 0

    def generate_reasoning_chain(
        self,
        image_features: torch.Tensor,
        prompt_ids: torch.Tensor,
        prompt_mask: torch.Tensor,
        max_steps: int = 10,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[float]]:
        """
        Generate a reasoning chain using the current policy.

        As a side effect, caches the generated token sequence and attention
        mask on ``self`` so :meth:`ppo_update` can recompute log-probs with
        gradients.

        NOTE(review): ``image_features`` is accepted but never fed to the
        policy forward pass below — generation conditions only on the token
        prompt. Confirm whether visual features should be injected here.

        Args:
            image_features: Visual features [vision_dim]
            prompt_ids: Tokenized prompt [seq_len]
            prompt_mask: Prompt attention mask [seq_len]
            max_steps: Maximum reasoning steps

        Returns:
            - step_embeddings: List of step embeddings
            - step_logprobs: List of log probabilities for each step
            - step_entropies: List of entropies for each step
        """
        self.policy.eval()

        step_embeddings: List[torch.Tensor] = []
        step_logprobs: List[torch.Tensor] = []
        step_entropies: List[float] = []

        current_ids = prompt_ids.unsqueeze(0).to(self.device)
        current_mask = prompt_mask.unsqueeze(0).to(self.device)
        prompt_len = current_ids.size(1)

        for _ in range(max_steps):
            with torch.no_grad():
                outputs = self.policy(
                    input_ids=current_ids,
                    attention_mask=current_mask,
                    output_hidden_states=True,
                )

                logits = outputs.logits[:, -1, :]
                hidden_state = outputs.hidden_states[-1][:, -1, :]

                logits = logits / self.config.temperature
                probs = F.softmax(logits, dim=-1)

                # Top-k filtering; k is clamped so torch.topk cannot fail
                # when the configured k exceeds the vocabulary size.
                if self.config.top_k > 0:
                    k = min(self.config.top_k, probs.size(-1))
                    top_k_probs, top_k_indices = torch.topk(probs, k, dim=-1)
                    probs = torch.zeros_like(probs).scatter_(1, top_k_indices, top_k_probs)
                    probs = probs / probs.sum(dim=-1, keepdim=True)

                # Nucleus (top-p) filtering.
                if self.config.top_p < 1.0:
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

                    sorted_indices_to_remove = cumulative_probs > self.config.top_p
                    # Shift right so the first token crossing the threshold
                    # is still kept.
                    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                    sorted_indices_to_remove[:, 0] = False

                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    probs = probs.masked_fill(indices_to_remove, 0.0)
                    probs = probs / probs.sum(dim=-1, keepdim=True)

                next_token = torch.multinomial(probs, num_samples=1)

                # Log-prob of the sampled token under the *filtered*
                # distribution, and the filtered distribution's entropy.
                log_prob = torch.log(probs.gather(1, next_token) + 1e-10)
                entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=-1)

                step_embeddings.append(hidden_state.squeeze(0))
                step_logprobs.append(log_prob.squeeze())
                step_entropies.append(entropy.item())

                # Append BEFORE the stop check so the cached sequence stays
                # aligned with the recorded log-probs (fix: the terminal
                # token used to be dropped).
                current_ids = torch.cat([current_ids, next_token], dim=1)
                current_mask = torch.cat([
                    current_mask,
                    torch.ones((1, 1), dtype=current_mask.dtype, device=self.device),
                ], dim=1)

                # Stop on EOS or token 198 (newline in GPT-2-style vocabs —
                # presumably marks step boundaries; confirm for other vocabs).
                if next_token.item() in (self.policy.config.eos_token_id, 198):
                    break

        # Cache rollout context for gradient-enabled recomputation.
        self._last_sequence_ids = current_ids.detach()
        self._last_attention_mask = current_mask.detach()
        self._last_prompt_len = prompt_len

        return step_embeddings, step_logprobs, step_entropies

    def compute_advantages(
        self,
        rewards: torch.Tensor,
        values: torch.Tensor,
        dones: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute Generalized Advantage Estimation (GAE).

        Args:
            rewards: Step rewards [num_steps]
            values: Value estimates [num_steps]
            dones: Done flags [num_steps] (1.0 marks a terminal step)

        Returns:
            - advantages: Advantage estimates [num_steps]
            - returns: Discounted returns [num_steps]
        """
        num_steps = len(rewards)
        advantages = torch.zeros_like(rewards)
        returns = torch.zeros_like(rewards)

        gae = 0
        next_value = 0

        # Standard backward GAE recursion; the last step bootstraps with
        # next_value = 0 (no value estimate beyond the rollout).
        for t in reversed(range(num_steps)):
            if t == num_steps - 1:
                next_non_terminal = 1.0 - dones[t]
                next_value = 0
            else:
                next_non_terminal = 1.0 - dones[t]
                next_value = values[t + 1]

            delta = rewards[t] + self.config.gamma * next_value * next_non_terminal - values[t]
            gae = delta + self.config.gamma * self.config.gae_lambda * next_non_terminal * gae

            advantages[t] = gae
            returns[t] = gae + values[t]

        return advantages, returns

    def ppo_update(
        self,
        step_embeddings: List[torch.Tensor],
        old_logprobs: torch.Tensor,
        rewards: torch.Tensor,
        advantages: torch.Tensor,
        returns: torch.Tensor,
        sequence_ids: Optional[torch.Tensor] = None,
        prompt_len: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, float]:
        """
        Perform PPO update.

        Args:
            step_embeddings: Hidden states for each step
            old_logprobs: Behavior-policy log probabilities (detached)
            rewards: Step rewards
            advantages: Advantage estimates
            returns: Discounted returns
            sequence_ids: Optional full rollout token ids [1, prompt+steps];
                when given, current log-probs are recomputed with gradients
                (the actual PPO ratio). Defaults to None for backward
                compatibility.
            prompt_len: Length of the prompt prefix in ``sequence_ids``.
            attention_mask: Optional attention mask matching ``sequence_ids``.

        Returns:
            Dictionary of training metrics
        """
        self.policy.train()
        self.value_head.train()

        embeddings = torch.stack(step_embeddings)
        num_steps = embeddings.size(0)

        # Fix: std() of a single element is NaN, which poisoned one-step
        # chains; only normalize when there is more than one step.
        if self.config.normalize_rewards and num_steps > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        can_recompute = (
            sequence_ids is not None
            and prompt_len is not None
            and sequence_ids.size(1) >= prompt_len + num_steps
        )

        total_policy_loss = 0.0
        total_value_loss = 0.0
        total_entropy_loss = 0.0

        for _ in range(self.config.num_epochs):
            # Critic: regress value estimates to the GAE returns.
            values = self.value_head(embeddings).squeeze(-1)
            value_loss = F.mse_loss(values, returns)

            if can_recompute:
                # Re-run the policy WITH gradients to obtain current
                # log-probs of the sampled tokens. Position i-1 predicts
                # token i, hence the one-off slice.
                mask = attention_mask if attention_mask is not None else torch.ones_like(sequence_ids)
                outputs = self.policy(input_ids=sequence_ids, attention_mask=mask)
                gen_logits = outputs.logits[0, prompt_len - 1: prompt_len + num_steps - 1, :]
                log_dist = F.log_softmax(gen_logits / self.config.temperature, dim=-1)
                tokens = sequence_ids[0, prompt_len: prompt_len + num_steps]
                new_logprobs = log_dist.gather(1, tokens.unsqueeze(1)).squeeze(1)
                # Exact, differentiable entropy of the current distribution.
                # NOTE: old_logprobs came from the top-k/top-p-filtered
                # distribution, so the initial ratio is only approximately 1.
                entropy = -(log_dist.exp() * log_dist).sum(dim=-1).mean()
            else:
                # Backward-compatible fallback: no rollout context, the
                # ratio degenerates to 1 (plain policy-gradient surrogate).
                new_logprobs = old_logprobs
                entropy = -old_logprobs.mean()  # Monte-Carlo entropy estimate

            # Fix: this was exp(old - old) == 1, so the clipped surrogate
            # carried no learning signal for the policy.
            ratio = torch.exp(new_logprobs - old_logprobs.detach())

            # Clipped surrogate objective.
            policy_loss_1 = -advantages * ratio
            policy_loss_2 = -advantages * torch.clamp(
                ratio,
                1 - self.config.clip_epsilon,
                1 + self.config.clip_epsilon,
            )
            policy_loss = torch.max(policy_loss_1, policy_loss_2).mean()

            # Entropy bonus. Fix: the sign was inverted, which penalized
            # exploration; subtracting entropy from the loss encourages it.
            entropy_loss = -self.config.entropy_coef * entropy

            loss = (
                policy_loss +
                self.config.value_loss_coef * value_loss +
                entropy_loss
            )

            self.optimizer.zero_grad()
            self.value_optimizer.zero_grad()
            loss.backward()

            torch.nn.utils.clip_grad_norm_(
                self.policy.parameters(),
                self.config.max_grad_norm,
            )
            torch.nn.utils.clip_grad_norm_(
                self.value_head.parameters(),
                self.config.max_grad_norm,
            )

            self.optimizer.step()
            self.value_optimizer.step()

            total_policy_loss += policy_loss.item()
            total_value_loss += value_loss.item()
            total_entropy_loss += entropy_loss.item()

        self.global_step += 1

        return {
            'policy_loss': total_policy_loss / self.config.num_epochs,
            'value_loss': total_value_loss / self.config.num_epochs,
            'entropy_loss': total_entropy_loss / self.config.num_epochs,
            'avg_reward': rewards.mean().item(),
            'avg_advantage': advantages.mean().item(),
            'avg_return': returns.mean().item(),
        }

    def train_step(
        self,
        batch: Dict[str, torch.Tensor],
    ) -> Dict[str, float]:
        """
        Single RL training step: rollout, PRM scoring, GAE, PPO update.

        Args:
            batch: Batch from StepDataset

        Returns:
            Training metrics averaged over the batch (empty dict if no
            sample produced a non-empty chain)
        """
        image_features = batch['visual_features'].to(self.device)
        prompt_ids = batch['prompt_ids'].to(self.device)
        prompt_mask = batch['prompt_mask'].to(self.device)

        batch_size = image_features.size(0)
        all_metrics = []

        # Chains have different lengths, so samples are processed one by one.
        for b in range(batch_size):
            step_embeddings, step_logprobs, step_entropies = self.generate_reasoning_chain(
                image_features[b],
                prompt_ids[b],
                prompt_mask[b],
                max_steps=10,
            )

            if not step_embeddings:
                continue

            # Score each step with the frozen PRM, conditioning on the
            # steps generated so far.
            step_rewards = []
            for i, step_emb in enumerate(step_embeddings):
                reward, _ = self.prm.compute_step_reward(
                    image_features[b],
                    step_emb,
                    previous_steps=torch.stack(step_embeddings[:i]) if i > 0 else None,
                )
                # float() coerces scalar tensors/np scalars safely; the old
                # torch.tensor([...]) over 0-d tensors was fragile.
                step_rewards.append(float(reward) * self.config.reward_scale)

            rewards = torch.tensor(step_rewards, dtype=torch.float32, device=self.device)
            logprobs = torch.stack(step_logprobs)

            embeddings = torch.stack(step_embeddings)
            with torch.no_grad():
                values = self.value_head(embeddings).squeeze(-1)

            # Only the final step terminates the episode.
            dones = torch.zeros(len(step_embeddings), device=self.device)
            dones[-1] = 1.0

            advantages, returns = self.compute_advantages(rewards, values, dones)

            metrics = self.ppo_update(
                step_embeddings,
                logprobs,
                rewards,
                advantages,
                returns,
                sequence_ids=self._last_sequence_ids,
                prompt_len=self._last_prompt_len,
                attention_mask=self._last_attention_mask,
            )
            all_metrics.append(metrics)

        if not all_metrics:
            return {}

        return {
            key: sum(m[key] for m in all_metrics) / len(all_metrics)
            for key in all_metrics[0]
        }

    def train(
        self,
        train_dataset: StepDataset,
        num_iterations: int = 1000,
        log_interval: int = 10,
    ) -> None:
        """
        Train the policy with RL.

        Note: each "iteration" sweeps the full dataloader once.

        Args:
            train_dataset: Dataset with reasoning chains
            num_iterations: Number of training iterations
            log_interval: Logging frequency (in iterations)
        """
        dataloader = DataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
        )

        logger.info("Starting RL training for %d iterations", num_iterations)

        for iteration in range(num_iterations):
            for batch in dataloader:
                metrics = self.train_step(batch)

                if metrics and iteration % log_interval == 0:
                    logger.info(
                        f"Iteration {iteration}: "
                        f"Policy Loss: {metrics['policy_loss']:.4f}, "
                        f"Value Loss: {metrics['value_loss']:.4f}, "
                        f"Avg Reward: {metrics['avg_reward']:.4f}"
                    )

        logger.info("RL training completed")

    def save_checkpoint(self, path: str) -> None:
        """Save policy, value head, both optimizers, and step counter."""
        checkpoint = {
            'policy_state_dict': self.policy.state_dict(),
            'value_head_state_dict': self.value_head.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'value_optimizer_state_dict': self.value_optimizer.state_dict(),
            'global_step': self.global_step,
            'config': self.config,
        }
        torch.save(checkpoint, path)
        logger.info(f"Checkpoint saved to {path}")

    def load_checkpoint(self, path: str) -> None:
        """Load a checkpoint produced by :meth:`save_checkpoint`.

        The checkpoint pickles the RLConfig object, so full (trusted)
        unpickling is required; torch >= 2.6 defaults to weights_only=True,
        which would reject it. Only load checkpoints from trusted sources.
        """
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.policy.load_state_dict(checkpoint['policy_state_dict'])
        self.value_head.load_state_dict(checkpoint['value_head_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.value_optimizer.load_state_dict(checkpoint['value_optimizer_state_dict'])
        self.global_step = checkpoint['global_step']
        logger.info(f"Checkpoint loaded from {path}")
| |
|