"""
Evaluation and ablation utilities for step-level reasoning.

Provides tools for measuring:
- Step quality and PRM accuracy
- Reasoning chain coherence
- Inference-time scaling effectiveness
- Ablation studies on different components
"""

from typing import Dict, List, Optional, Tuple, Any
import torch
import numpy as np
from pathlib import Path
import json
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import logging

from .step_data import ReasoningChain, ReasoningStep, StepType
from .prm import ProcessRewardModel, StepQualityMetrics
from .inference_scaling import InferenceTimeScaling

logger = logging.getLogger(__name__)


def _answers_match(predicted: Any, expected: Any) -> bool:
    """Return True when two answers are equal ignoring case and outer whitespace.

    A missing (``None``) answer on either side never matches.
    """
    if predicted is None or expected is None:
        return False
    return str(predicted).strip().lower() == str(expected).strip().lower()


def _answer_accuracy(chains: List[Any], ground_truth: List[Any]) -> float:
    """Return the fraction of ``chains`` whose ``final_answer`` matches its ground truth.

    Chains and answers are paired positionally; surplus items on either side
    are not compared. The denominator is ``len(chains)``, and an empty chain
    list yields ``0.0`` instead of a division by zero.
    """
    if not chains:
        return 0.0
    if len(chains) != len(ground_truth):
        logger.warning(
            "Got %d chains but %d ground-truth answers; unpaired items are not compared",
            len(chains),
            len(ground_truth),
        )
    correct = sum(
        1
        for chain, expected in zip(chains, ground_truth)
        if _answers_match(chain.final_answer, expected)
    )
    return correct / len(chains)


def _json_default(value: Any) -> Any:
    """``json.dump`` fallback that converts NumPy scalars/arrays to built-in types."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(
        f"Object of type {type(value).__name__} is not JSON serializable"
    )


class ReasoningEvaluator:
    """Comprehensive evaluator for step-level reasoning systems."""

    def __init__(
        self,
        prm: Optional[ProcessRewardModel] = None,
        device: str = "cuda",
    ):
        """
        Initialize evaluator.

        Args:
            prm: Process Reward Model for evaluation
            device: Device for computation. A CUDA device is replaced by
                "cpu" (with a warning) when CUDA is not available.
        """
        if str(device).startswith("cuda") and not torch.cuda.is_available():
            logger.warning("CUDA is not available; falling back to CPU")
            device = "cpu"

        self.prm = prm
        self.device = device

        # Compare against None explicitly: an object's truthiness may depend
        # on __len__/__bool__ and is not a reliable "was it provided" test.
        if self.prm is not None:
            self.prm.to(device)
            self.prm.eval()

    def evaluate_step_quality(
        self,
        chains: List[ReasoningChain],
        ground_truth_rewards: Optional[List[List[float]]] = None,
    ) -> Dict[str, Any]:
        """
        Evaluate step-level quality metrics.

        Args:
            chains: List of reasoning chains
            ground_truth_rewards: Optional ground truth rewards for PRM
                evaluation, one list per chain, aligned with that chain's steps

        Returns:
            Dictionary of evaluation metrics; empty when the chains contain
            no steps. ``step_type_distribution`` is a nested dict of counts.
            The ``prm_*`` entries are only present when ground truth is given
            and a PRM is configured.
        """
        all_steps = [step for chain in chains for step in chain.steps]

        if not all_steps:
            return {}

        # Basic statistics
        step_rewards = [step.reward for step in all_steps]
        step_confidences = [step.confidence for step in all_steps]

        metrics: Dict[str, Any] = {
            'num_chains': len(chains),
            'total_steps': len(all_steps),
            'avg_steps_per_chain': len(all_steps) / len(chains),
            'avg_step_reward': float(np.mean(step_rewards)),
            'std_step_reward': float(np.std(step_rewards)),
            'min_step_reward': float(np.min(step_rewards)),
            'max_step_reward': float(np.max(step_rewards)),
            'avg_step_confidence': float(np.mean(step_confidences)),
        }

        # Step type distribution
        step_type_counts = defaultdict(int)
        for step in all_steps:
            step_type_counts[step.step_type.value] += 1
        metrics['step_type_distribution'] = dict(step_type_counts)

        # PRM accuracy (if ground truth available). This compares the reward
        # already stored on each step with the ground truth; the PRM itself
        # is not invoked here.
        if ground_truth_rewards and self.prm is not None:
            predicted_rewards = []
            true_rewards = []
            for chain, gt_rewards in zip(chains, ground_truth_rewards):
                for step, gt_reward in zip(chain.steps, gt_rewards):
                    predicted_rewards.append(step.reward)
                    true_rewards.append(gt_reward)

            if predicted_rewards:
                predicted = np.asarray(predicted_rewards, dtype=float)
                actual = np.asarray(true_rewards, dtype=float)
                errors = predicted - actual

                metrics['prm_mse'] = float(np.mean(errors ** 2))
                metrics['prm_mae'] = float(np.mean(np.abs(errors)))

                # Pearson correlation is undefined for fewer than two points
                # or a constant series; report NaN without triggering NumPy's
                # divide-by-zero warning.
                if predicted.size > 1 and predicted.std() > 0 and actual.std() > 0:
                    correlation = float(np.corrcoef(predicted, actual)[0, 1])
                else:
                    correlation = float('nan')
                metrics['prm_correlation'] = correlation

        return metrics

    def evaluate_chain_coherence(
        self,
        chains: List[ReasoningChain],
    ) -> Dict[str, float]:
        """
        Evaluate reasoning chain coherence and structure.

        Coherence of a chain is ``1 / (1 + std)`` of the absolute reward
        changes between consecutive steps; chains with fewer than two steps
        do not contribute a coherence score. The "dependency depth" of a
        chain is the largest number of direct dependencies listed on any
        single step.

        Args:
            chains: List of reasoning chains

        Returns:
            Coherence metrics (zeros when there is nothing to measure)
        """
        coherence_scores = []
        dependency_depths = []

        for chain in chains:
            # Measure reward consistency (no sudden drops)
            if len(chain.steps) > 1:
                rewards = [step.reward for step in chain.steps]
                reward_diffs = np.abs(np.diff(rewards))
                # Lower variance = more coherent
                coherence_scores.append(1.0 / (1.0 + float(np.std(reward_diffs))))

            # Measure dependency structure depth
            dependency_depths.append(
                max(
                    (len(step.dependencies) for step in chain.steps if step.dependencies),
                    default=0,
                )
            )

        return {
            'avg_coherence': float(np.mean(coherence_scores)) if coherence_scores else 0.0,
            'std_coherence': float(np.std(coherence_scores)) if coherence_scores else 0.0,
            'avg_dependency_depth': float(np.mean(dependency_depths)) if dependency_depths else 0.0,
            'max_dependency_depth': max(dependency_depths) if dependency_depths else 0,
        }

    def evaluate_inference_scaling(
        self,
        chains_by_sample_count: Dict[int, List[ReasoningChain]],
        ground_truth: List[str],
    ) -> Dict[str, Any]:
        """
        Evaluate effectiveness of inference-time scaling.

        Args:
            chains_by_sample_count: Mapping from num_samples to generated chains
            ground_truth: Ground truth answers

        Returns:
            Scaling effectiveness metrics keyed by sample count, plus a
            ``'scaling_benefit'`` entry (largest vs. smallest sample count)
            when more than one sample count is present. A sample count with
            no chains gets zeros for all of its metrics.
        """
        results: Dict[Any, Any] = {}

        for num_samples, chains in sorted(chains_by_sample_count.items()):
            results[num_samples] = {
                'accuracy': _answer_accuracy(chains, ground_truth),
                # Guard the means: np.mean([]) is NaN and emits a warning.
                'avg_reward': float(np.mean([c.total_reward for c in chains])) if chains else 0.0,
                'avg_steps': float(np.mean([len(c) for c in chains])) if chains else 0.0,
            }

        # Calculate scaling benefit
        if len(results) > 1:
            # Read the keys before inserting the string-keyed summary so the
            # sort only ever sees sample counts.
            sample_counts = sorted(results.keys())
            baseline_acc = results[sample_counts[0]]['accuracy']
            best_acc = results[sample_counts[-1]]['accuracy']

            improvement = best_acc - baseline_acc
            results['scaling_benefit'] = {
                'accuracy_improvement': improvement,
                'relative_improvement': improvement / baseline_acc if baseline_acc > 0 else 0,
            }

        return results

    def ablation_study(
        self,
        model: Any,
        test_chains: List[ReasoningChain],
        components: Optional[List[str]] = None,
    ) -> Dict[str, Dict[str, float]]:
        """
        Perform ablation study on different components.

        Args:
            model: Vision-language model
            test_chains: Test reasoning chains
            components: Components to ablate. Defaults to
                ``["prm", "rl", "inference_scaling"]``; ``"baseline"`` is
                also recognised but must be requested explicitly.

        Returns:
            Ablation results for each component, always including
            ``'full_system'``
        """
        # A None sentinel avoids sharing one mutable default list across calls.
        if components is None:
            components = ["prm", "rl", "inference_scaling"]

        results = {}

        # Baseline: no reasoning
        if "baseline" in components:
            logger.info("Evaluating baseline (no reasoning)")
            results['baseline'] = self._evaluate_without_reasoning(model, test_chains)

        # With PRM only
        if "prm" in components:
            if self.prm is not None:
                logger.info("Evaluating with PRM only")
                results['prm_only'] = self._evaluate_with_prm_only(model, test_chains)
            else:
                logger.warning("Skipping PRM ablation: no PRM configured")

        # With RL only (no PRM)
        if "rl" in components:
            logger.info("Evaluating with RL only")
            results['rl_only'] = self._evaluate_with_rl_only(model, test_chains)

        # With inference scaling only
        if "inference_scaling" in components:
            logger.info("Evaluating with inference scaling only")
            results['inference_scaling_only'] = self._evaluate_with_scaling_only(model, test_chains)

        # Full system
        logger.info("Evaluating full system")
        results['full_system'] = self._evaluate_full_system(model, test_chains)

        return results

    def _evaluate_without_reasoning(
        self,
        model: Any,
        test_chains: List[ReasoningChain],
    ) -> Dict[str, float]:
        """Baseline evaluation without step-level reasoning.

        Placeholder: ignores its arguments and returns zeroed metrics.
        """
        # Implementation depends on your model
        # This is a placeholder
        return {
            'accuracy': 0.0,
            'avg_reward': 0.0,
        }

    def _evaluate_with_prm_only(
        self,
        model: Any,
        test_chains: List[ReasoningChain],
    ) -> Dict[str, float]:
        """Evaluate with PRM but no RL training.

        Placeholder: ignores its arguments and returns zeroed metrics.
        """
        # Use PRM for evaluation but model trained with standard loss
        return {
            'accuracy': 0.0,
            'avg_prm_reward': 0.0,
        }

    def _evaluate_with_rl_only(
        self,
        model: Any,
        test_chains: List[ReasoningChain],
    ) -> Dict[str, float]:
        """Evaluate with RL but no PRM (using outcome rewards only).

        Placeholder: ignores its arguments and returns zeroed metrics.
        """
        return {
            'accuracy': 0.0,
            'avg_reward': 0.0,
        }

    def _evaluate_with_scaling_only(
        self,
        model: Any,
        test_chains: List[ReasoningChain],
    ) -> Dict[str, float]:
        """Evaluate with inference scaling but no RL or PRM training.

        Placeholder: ignores its arguments and returns zeroed metrics.
        """
        return {
            'accuracy': 0.0,
            'avg_reward': 0.0,
        }

    def _evaluate_full_system(
        self,
        model: Any,
        test_chains: List[ReasoningChain],
    ) -> Dict[str, float]:
        """Evaluate complete system with all components.

        Placeholder: ignores its arguments and returns zeroed metrics.
        """
        return {
            'accuracy': 0.0,
            'avg_reward': 0.0,
        }

    def visualize_reasoning_chains(
        self,
        chains: List[ReasoningChain],
        save_path: Optional[str] = None,
    ) -> None:
        """
        Visualize reasoning chains and step rewards.

        Draws a 2x2 figure: step reward histogram, chain length vs total
        reward, step type counts, and cumulative reward for the first 10
        chains.

        Args:
            chains: Reasoning chains to visualize
            save_path: Optional path to save figure; when omitted the figure
                is shown instead
        """
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))

        # 1. Step reward distribution
        all_rewards = [step.reward for chain in chains for step in chain.steps]
        axes[0, 0].hist(all_rewards, bins=50, edgecolor='black')
        axes[0, 0].set_xlabel('Step Reward')
        axes[0, 0].set_ylabel('Frequency')
        axes[0, 0].set_title('Distribution of Step Rewards')
        # The mean of no rewards is NaN, so only mark it when data exists.
        if all_rewards:
            axes[0, 0].axvline(np.mean(all_rewards), color='red', linestyle='--', label='Mean')
            axes[0, 0].legend()

        # 2. Chain length vs total reward
        chain_lengths = [len(chain) for chain in chains]
        chain_rewards = [chain.total_reward for chain in chains]
        axes[0, 1].scatter(chain_lengths, chain_rewards, alpha=0.6)
        axes[0, 1].set_xlabel('Number of Steps')
        axes[0, 1].set_ylabel('Total Reward')
        axes[0, 1].set_title('Chain Length vs Total Reward')

        # 3. Step type distribution
        step_types = defaultdict(int)
        for chain in chains:
            for step in chain.steps:
                step_types[step.step_type.value] += 1

        axes[1, 0].bar(list(step_types.keys()), list(step_types.values()))
        axes[1, 0].set_xlabel('Step Type')
        axes[1, 0].set_ylabel('Count')
        axes[1, 0].set_title('Step Type Distribution')
        axes[1, 0].tick_params(axis='x', rotation=45)

        # 4. Cumulative reward progression
        for i, chain in enumerate(chains[:10]):  # Plot first 10 chains
            cumulative_rewards = chain.get_cumulative_rewards()
            axes[1, 1].plot(
                range(1, len(cumulative_rewards) + 1),
                cumulative_rewards,
                alpha=0.6,
                label=f'Chain {i+1}',
            )
        axes[1, 1].set_xlabel('Step Number')
        axes[1, 1].set_ylabel('Cumulative Reward')
        axes[1, 1].set_title('Cumulative Reward Progression (First 10 Chains)')
        # Skip the legend when nothing was plotted to avoid an empty-legend warning.
        if axes[1, 1].get_legend_handles_labels()[0]:
            axes[1, 1].legend(bbox_to_anchor=(1.05, 1), loc='upper left')

        plt.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            # pyplot keeps every figure alive until it is closed; release it
            # so repeated report generation does not accumulate figures.
            plt.close(fig)
            logger.info("Visualization saved to %s", save_path)
        else:
            plt.show()

    def generate_evaluation_report(
        self,
        chains: List[ReasoningChain],
        output_path: str,
        include_ablation: bool = False,
    ) -> None:
        """
        Generate comprehensive evaluation report.

        Writes ``<output_path>.json`` (metrics) and ``<output_path>.png``
        (visualizations), replacing any existing suffix on ``output_path``.

        Args:
            chains: Reasoning chains to evaluate
            output_path: Path to save report
            include_ablation: Whether to include ablation study. Currently
                not supported here (an ablation needs a model, which this
                method does not receive); a warning is logged when set.
        """
        if include_ablation:
            logger.warning(
                "include_ablation is not supported by generate_evaluation_report; "
                "call ablation_study() with a model instead"
            )

        report = {
            'summary': {
                'num_chains': len(chains),
                'total_steps': sum(len(c) for c in chains),
            },
            'step_quality': self.evaluate_step_quality(chains),
            'chain_coherence': self.evaluate_chain_coherence(chains),
        }

        base_path = Path(output_path)
        base_path.parent.mkdir(parents=True, exist_ok=True)

        # Save JSON report
        json_path = base_path.with_suffix('.json')
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=2, default=_json_default)

        logger.info("Evaluation report saved to %s", json_path)

        # Save visualizations
        viz_path = base_path.with_suffix('.png')
        self.visualize_reasoning_chains(chains, str(viz_path))

        logger.info("Visualizations saved to %s", viz_path)


def compare_models(
    model_chains: Dict[str, List[ReasoningChain]],
    ground_truth: List[str],
    output_dir: str,
) -> Dict[str, Dict[str, float]]:
    """
    Compare multiple models on reasoning tasks.

    Writes ``model_comparison.json`` and ``model_comparison.png`` into
    ``output_dir`` (created if missing).

    Args:
        model_chains: Mapping from model name to reasoning chains
        ground_truth: Ground truth answers
        output_dir: Directory to save comparison results

    Returns:
        Comparison metrics for each model. A model with no chains gets an
        accuracy of 0.0 and no step-quality metrics.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    evaluator = ReasoningEvaluator()

    results = {}

    for model_name, chains in model_chains.items():
        logger.info("Evaluating %s", model_name)

        results[model_name] = {
            'accuracy': _answer_accuracy(chains, ground_truth),
            **evaluator.evaluate_step_quality(chains),
            **evaluator.evaluate_chain_coherence(chains),
        }

    # Save comparison
    with open(output_path / 'model_comparison.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, default=_json_default)

    # Create comparison plots
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    model_names = list(results.keys())
    accuracies = [results[m]['accuracy'] for m in model_names]
    # Step-quality metrics are absent for models without any steps, so fall
    # back to 0.0 rather than raising KeyError.
    avg_rewards = [results[m].get('avg_step_reward', 0.0) for m in model_names]
    avg_steps = [results[m].get('avg_steps_per_chain', 0.0) for m in model_names]

    axes[0].bar(model_names, accuracies)
    axes[0].set_ylabel('Accuracy')
    axes[0].set_title('Model Accuracy Comparison')
    axes[0].tick_params(axis='x', rotation=45)

    axes[1].bar(model_names, avg_rewards)
    axes[1].set_ylabel('Average Step Reward')
    axes[1].set_title('Average Step Reward Comparison')
    axes[1].tick_params(axis='x', rotation=45)

    axes[2].bar(model_names, avg_steps)
    axes[2].set_ylabel('Average Steps per Chain')
    axes[2].set_title('Average Chain Length Comparison')
    axes[2].tick_params(axis='x', rotation=45)

    plt.tight_layout()
    fig.savefig(output_path / 'model_comparison.png', dpi=300, bbox_inches='tight')
    # Release the figure; pyplot otherwise keeps it alive indefinitely.
    plt.close(fig)

    logger.info("Model comparison saved to %s", output_path)

    return results