| """ |
| Evaluation and ablation utilities for step-level reasoning. |
| |
| Provides tools for measuring: |
| - Step quality and PRM accuracy |
| - Reasoning chain coherence |
| - Inference-time scaling effectiveness |
| - Ablation studies on different components |
| """ |
|
|
| from typing import Dict, List, Optional, Tuple, Any |
| import torch |
| import numpy as np |
| from pathlib import Path |
| import json |
| import matplotlib.pyplot as plt |
| import seaborn as sns |
| from collections import defaultdict |
| import logging |
|
|
| from .step_data import ReasoningChain, ReasoningStep, StepType |
| from .prm import ProcessRewardModel, StepQualityMetrics |
| from .inference_scaling import InferenceTimeScaling |
|
|
# Module-level logger named after this module. No handlers or levels are
# configured in this file; output routing is left to the application.
logger = logging.getLogger(__name__)
|
|
|
|
class ReasoningEvaluator:
    """Comprehensive evaluator for step-level reasoning systems.

    Computes aggregate statistics over reasoning chains (step rewards,
    confidences, step-type counts, reward-smoothness "coherence", answer
    accuracy across sample counts), drives a component ablation study, and
    renders/saves summary plots and JSON reports.
    """

    def __init__(
        self,
        prm: Optional[ProcessRewardModel] = None,
        device: str = "cuda",
    ):
        """
        Initialize evaluator.

        Args:
            prm: Optional Process Reward Model. When given it is moved to
                ``device`` and switched to eval mode. Its presence also gates
                the PRM-vs-ground-truth metrics in ``evaluate_step_quality``.
            device: Device the PRM is moved to (only used when ``prm`` is given).
        """
        self.prm = prm
        self.device = device

        # Compare against None explicitly: objects that define __len__
        # (e.g. container-style modules) can be falsy even when present.
        if self.prm is not None:
            self.prm.to(device)
            self.prm.eval()

    @staticmethod
    def _normalize_answer(answer: Optional[str]) -> str:
        """Return a case/whitespace-insensitive form of an answer; None maps to ''."""
        return (answer or "").strip().lower()

    @staticmethod
    def _safe_correlation(x: np.ndarray, y: np.ndarray) -> float:
        """Return the Pearson correlation of ``x`` and ``y``, or NaN when undefined.

        ``np.corrcoef`` emits RuntimeWarnings (and yields NaN) for fewer than
        two points or zero-variance inputs; return NaN directly in those cases.
        """
        if x.size < 2 or np.std(x) == 0 or np.std(y) == 0:
            return float('nan')
        return float(np.corrcoef(x, y)[0, 1])

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """``json.dump`` fallback that converts NumPy scalars/arrays to built-ins."""
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def evaluate_step_quality(
        self,
        chains: List[ReasoningChain],
        ground_truth_rewards: Optional[List[List[float]]] = None,
    ) -> Dict[str, float]:
        """
        Evaluate step-level quality metrics.

        Args:
            chains: List of reasoning chains.
            ground_truth_rewards: Optional per-chain lists of reference step
                rewards, aligned with ``chains`` and with each chain's steps
                (extra entries on either side are ignored). Used only when a
                PRM is configured: the stored ``step.reward`` values are then
                compared against these references.

        Returns:
            Dictionary of evaluation metrics as built-in Python numbers, plus a
            nested ``step_type_distribution`` mapping. Empty dict when
            ``chains`` contain no steps. ``prm_correlation`` is NaN when fewer
            than two pairs exist or either side is constant.
        """
        all_steps = [step for chain in chains for step in chain.steps]

        if not all_steps:
            return {}

        step_rewards = [step.reward for step in all_steps]
        step_confidences = [step.confidence for step in all_steps]

        metrics: Dict[str, Any] = {
            'num_chains': len(chains),
            'total_steps': len(all_steps),
            'avg_steps_per_chain': len(all_steps) / len(chains),
            # float() turns NumPy scalars (np.int64 for integer rewards) into
            # built-ins so the metrics are always JSON-serializable.
            'avg_step_reward': float(np.mean(step_rewards)),
            'std_step_reward': float(np.std(step_rewards)),
            'min_step_reward': float(np.min(step_rewards)),
            'max_step_reward': float(np.max(step_rewards)),
            'avg_step_confidence': float(np.mean(step_confidences)),
        }

        step_type_counts: Dict[Any, int] = defaultdict(int)
        for step in all_steps:
            step_type_counts[step.step_type.value] += 1
        metrics['step_type_distribution'] = dict(step_type_counts)

        if ground_truth_rewards and self.prm is not None:
            pairs = [
                (step.reward, gt_reward)
                for chain, gt_rewards in zip(chains, ground_truth_rewards)
                for step, gt_reward in zip(chain.steps, gt_rewards)
            ]
            if pairs:
                predicted = np.array([p for p, _ in pairs], dtype=float)
                true = np.array([t for _, t in pairs], dtype=float)
                errors = predicted - true

                metrics['prm_mse'] = float(np.mean(errors ** 2))
                metrics['prm_mae'] = float(np.mean(np.abs(errors)))
                metrics['prm_correlation'] = self._safe_correlation(predicted, true)

        return metrics

    def evaluate_chain_coherence(
        self,
        chains: List[ReasoningChain],
    ) -> Dict[str, float]:
        """
        Evaluate reasoning chain coherence and structure.

        Coherence of a chain with at least two steps is
        ``1 / (1 + std(|r_i - r_{i-1}|))`` over consecutive step rewards: 1.0
        when the reward changes by a constant amount each step, approaching 0
        as the changes become erratic. Single-step chains get no coherence
        score.

        Args:
            chains: List of reasoning chains.

        Returns:
            Coherence metrics (0 for each metric when there is no data).
        """
        coherence_scores: List[float] = []
        dependency_depths: List[int] = []

        for chain in chains:
            steps = list(chain.steps)

            if len(steps) > 1:
                reward_diffs = [
                    abs(curr.reward - prev.reward)
                    for prev, curr in zip(steps, steps[1:])
                ]
                coherence_scores.append(1.0 / (1.0 + float(np.std(reward_diffs))))

            # NOTE: despite the metric name, this is the largest number of
            # direct dependencies on any single step (fan-in), not the length
            # of the longest dependency path. Name kept for compatibility.
            dependency_depths.append(
                max((len(step.dependencies) for step in steps if step.dependencies), default=0)
            )

        return {
            'avg_coherence': float(np.mean(coherence_scores)) if coherence_scores else 0.0,
            'std_coherence': float(np.std(coherence_scores)) if coherence_scores else 0.0,
            'avg_dependency_depth': float(np.mean(dependency_depths)) if dependency_depths else 0.0,
            'max_dependency_depth': max(dependency_depths) if dependency_depths else 0,
        }

    def evaluate_inference_scaling(
        self,
        chains_by_sample_count: Dict[int, List[ReasoningChain]],
        ground_truth: List[str],
    ) -> Dict[str, Any]:
        """
        Evaluate effectiveness of inference-time scaling.

        Args:
            chains_by_sample_count: Mapping from num_samples to generated chains,
                each list aligned index-by-index with ``ground_truth``.
            ground_truth: Ground truth answers (compared case- and
                whitespace-insensitively).

        Returns:
            Per-sample-count metrics keyed by ``num_samples``. When more than
            one sample count is present, a ``'scaling_benefit'`` entry compares
            the smallest sample count against the largest. Accuracy is computed
            over the aligned (chain, answer) pairs only; all metrics are 0.0
            for an empty chain list.
        """
        results: Dict[Any, Dict[str, float]] = {}

        for num_samples, chains in sorted(chains_by_sample_count.items()):
            if len(chains) != len(ground_truth):
                logger.warning(
                    "num_samples=%s: %d chains vs %d ground-truth answers; "
                    "scoring only the aligned pairs",
                    num_samples, len(chains), len(ground_truth),
                )

            pairs = list(zip(chains, ground_truth))
            correct = sum(
                self._normalize_answer(chain.final_answer) == self._normalize_answer(gt)
                for chain, gt in pairs
            )
            # Divide by the number of scored pairs, not len(chains): zip()
            # truncates, so chains without a reference answer are not "wrong".
            accuracy = correct / len(pairs) if pairs else 0.0

            results[num_samples] = {
                'accuracy': accuracy,
                'avg_reward': float(np.mean([c.total_reward for c in chains])) if chains else 0.0,
                'avg_steps': float(np.mean([len(c) for c in chains])) if chains else 0.0,
            }

        if len(results) > 1:
            sample_counts = sorted(results)
            baseline_acc = results[sample_counts[0]]['accuracy']
            largest_acc = results[sample_counts[-1]]['accuracy']
            improvement = largest_acc - baseline_acc

            results['scaling_benefit'] = {
                'accuracy_improvement': improvement,
                'relative_improvement': improvement / baseline_acc if baseline_acc > 0 else 0,
            }

        return results

    def ablation_study(
        self,
        model: Any,
        test_chains: List[ReasoningChain],
        components: Optional[List[str]] = None,
    ) -> Dict[str, Dict[str, float]]:
        """
        Perform ablation study on different components.

        Args:
            model: Vision-language model.
            test_chains: Test reasoning chains.
            components: Components to ablate. Defaults to
                ``["prm", "rl", "inference_scaling"]``. Include ``"baseline"``
                to also evaluate the no-reasoning baseline. ``"prm"`` is skipped
                when no PRM is configured.

        Returns:
            Mapping from configuration name (``baseline``, ``prm_only``,
            ``rl_only``, ``inference_scaling_only``, ``full_system``) to its
            metrics. ``full_system`` is always evaluated.
        """
        if components is None:
            # None sentinel instead of a list-literal default, so no mutable
            # default object is shared across calls.
            components = ["prm", "rl", "inference_scaling"]

        results: Dict[str, Dict[str, float]] = {}

        if "baseline" in components:
            logger.info("Evaluating baseline (no reasoning)")
            results['baseline'] = self._evaluate_without_reasoning(model, test_chains)

        if "prm" in components and self.prm is not None:
            logger.info("Evaluating with PRM only")
            results['prm_only'] = self._evaluate_with_prm_only(model, test_chains)

        if "rl" in components:
            logger.info("Evaluating with RL only")
            results['rl_only'] = self._evaluate_with_rl_only(model, test_chains)

        if "inference_scaling" in components:
            logger.info("Evaluating with inference scaling only")
            results['inference_scaling_only'] = self._evaluate_with_scaling_only(model, test_chains)

        logger.info("Evaluating full system")
        results['full_system'] = self._evaluate_full_system(model, test_chains)

        return results

    # NOTE(review): the five _evaluate_* helpers below are placeholders. They
    # ignore `model` and `test_chains` and return zeroed metrics.

    def _evaluate_without_reasoning(
        self,
        model: Any,
        test_chains: List[ReasoningChain],
    ) -> Dict[str, float]:
        """Baseline evaluation without step-level reasoning (placeholder: returns zeros)."""
        # TODO: implement baseline evaluation.
        return {
            'accuracy': 0.0,
            'avg_reward': 0.0,
        }

    def _evaluate_with_prm_only(
        self,
        model: Any,
        test_chains: List[ReasoningChain],
    ) -> Dict[str, float]:
        """Evaluate with PRM but no RL training (placeholder: returns zeros)."""
        # TODO: implement PRM-only evaluation.
        return {
            'accuracy': 0.0,
            'avg_prm_reward': 0.0,
        }

    def _evaluate_with_rl_only(
        self,
        model: Any,
        test_chains: List[ReasoningChain],
    ) -> Dict[str, float]:
        """Evaluate with RL but no PRM, using outcome rewards only (placeholder: returns zeros)."""
        # TODO: implement RL-only evaluation.
        return {
            'accuracy': 0.0,
            'avg_reward': 0.0,
        }

    def _evaluate_with_scaling_only(
        self,
        model: Any,
        test_chains: List[ReasoningChain],
    ) -> Dict[str, float]:
        """Evaluate with inference scaling but no RL or PRM training (placeholder: returns zeros)."""
        # TODO: implement inference-scaling-only evaluation.
        return {
            'accuracy': 0.0,
            'avg_reward': 0.0,
        }

    def _evaluate_full_system(
        self,
        model: Any,
        test_chains: List[ReasoningChain],
    ) -> Dict[str, float]:
        """Evaluate complete system with all components (placeholder: returns zeros)."""
        # TODO: implement full-system evaluation.
        return {
            'accuracy': 0.0,
            'avg_reward': 0.0,
        }

    def visualize_reasoning_chains(
        self,
        chains: List[ReasoningChain],
        save_path: Optional[str] = None,
    ) -> None:
        """
        Visualize reasoning chains and step rewards.

        Draws a 2x2 grid with four panels:

        - a step-reward histogram with a mean line;
        - chain length vs. total reward;
        - step-type counts;
        - cumulative-reward curves for the first 10 chains.

        Args:
            chains: Reasoning chains to visualize.
            save_path: Optional path to save the figure to. When given, the
                figure is written at 300 dpi and then closed. Otherwise it is
                displayed with ``plt.show()``.
        """
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))

        # Top-left: distribution of per-step rewards.
        all_rewards = [step.reward for chain in chains for step in chain.steps]
        ax = axes[0, 0]
        ax.hist(all_rewards, bins=50, edgecolor='black')
        ax.set_xlabel('Step Reward')
        ax.set_ylabel('Frequency')
        ax.set_title('Distribution of Step Rewards')
        if all_rewards:
            # np.mean([]) is NaN with a RuntimeWarning; only mark the mean
            # (and add its legend) when there is data.
            ax.axvline(np.mean(all_rewards), color='red', linestyle='--', label='Mean')
            ax.legend()

        # Top-right: chain length vs. total reward.
        ax = axes[0, 1]
        ax.scatter([len(chain) for chain in chains],
                   [chain.total_reward for chain in chains], alpha=0.6)
        ax.set_xlabel('Number of Steps')
        ax.set_ylabel('Total Reward')
        ax.set_title('Chain Length vs Total Reward')

        # Bottom-left: how often each step type occurs.
        step_types: Dict[Any, int] = defaultdict(int)
        for chain in chains:
            for step in chain.steps:
                step_types[step.step_type.value] += 1
        ax = axes[1, 0]
        # Pass plain lists rather than dict views to Matplotlib.
        ax.bar(list(step_types.keys()), list(step_types.values()))
        ax.set_xlabel('Step Type')
        ax.set_ylabel('Count')
        ax.set_title('Step Type Distribution')
        ax.tick_params(axis='x', rotation=45)

        # Bottom-right: cumulative reward curves for the first 10 chains.
        ax = axes[1, 1]
        for idx, chain in enumerate(chains[:10], start=1):
            cumulative_rewards = chain.get_cumulative_rewards()
            ax.plot(range(1, len(cumulative_rewards) + 1), cumulative_rewards,
                    alpha=0.6, label=f'Chain {idx}')
        ax.set_xlabel('Step Number')
        ax.set_ylabel('Cumulative Reward')
        ax.set_title('Cumulative Reward Progression (First 10 Chains)')
        if chains:
            # Without plotted lines legend() only warns "No artists with labels".
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            # pyplot keeps every open figure alive; close it so repeated report
            # generation does not accumulate figures in memory.
            plt.close(fig)
            logger.info("Visualization saved to %s", save_path)
        else:
            plt.show()

    def generate_evaluation_report(
        self,
        chains: List[ReasoningChain],
        output_path: str,
        include_ablation: bool = False,
    ) -> None:
        """
        Generate comprehensive evaluation report.

        Writes ``output_path`` with its suffix replaced by ``.json`` (summary,
        step-quality and coherence metrics) and by ``.png`` (see
        ``visualize_reasoning_chains``). Missing parent directories are
        created.

        Args:
            chains: Reasoning chains to evaluate.
            output_path: Base path for the report files.
            include_ablation: Accepted for API compatibility but not supported.
                An ablation study needs a model, which this method does not
                receive, so a warning is logged when this is set.
        """
        if include_ablation:
            logger.warning(
                "include_ablation=True is not supported by generate_evaluation_report "
                "(no model available); call ablation_study() directly"
            )

        report = {
            'summary': {
                'num_chains': len(chains),
                'total_steps': sum(len(c) for c in chains),
            },
            'step_quality': self.evaluate_step_quality(chains),
            'chain_coherence': self.evaluate_chain_coherence(chains),
        }

        base_path = Path(output_path)
        base_path.parent.mkdir(parents=True, exist_ok=True)

        json_path = base_path.with_suffix('.json')
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=2, default=self._json_default)

        logger.info("Evaluation report saved to %s", json_path)

        viz_path = base_path.with_suffix('.png')
        self.visualize_reasoning_chains(chains, str(viz_path))

        logger.info("Visualizations saved to %s", viz_path)
|
|
|
|
def compare_models(
    model_chains: Dict[str, List[ReasoningChain]],
    ground_truth: List[str],
    output_dir: str,
) -> Dict[str, Dict[str, float]]:
    """
    Compare multiple models on reasoning tasks.

    For each model, computes exact-match accuracy against ``ground_truth``
    (case- and whitespace-insensitive). It adds the step-quality and coherence
    metrics from ``ReasoningEvaluator``. Writes ``model_comparison.json`` and
    ``model_comparison.png`` into ``output_dir``.

    Args:
        model_chains: Mapping from model name to reasoning chains, aligned
            index-by-index with ``ground_truth``.
        ground_truth: Ground truth answers.
        output_dir: Directory to save comparison results (created if missing).

    Returns:
        Comparison metrics for each model. Accuracy is computed over the
        aligned (chain, answer) pairs and is 0.0 when there are none.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    def _normalize(answer: Optional[str]) -> str:
        # Treat a missing answer as empty instead of crashing on .strip().
        return (answer or "").strip().lower()

    def _json_default(obj: Any) -> Any:
        # Metrics may contain NumPy scalars (e.g. np.int64 when rewards are
        # integers), which the json module cannot serialize on its own.
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    evaluator = ReasoningEvaluator()

    results: Dict[str, Dict[str, Any]] = {}
    for model_name, chains in model_chains.items():
        logger.info("Evaluating %s", model_name)

        if len(chains) != len(ground_truth):
            logger.warning(
                "%s: %d chains vs %d ground-truth answers; scoring only aligned pairs",
                model_name, len(chains), len(ground_truth),
            )
        pairs = list(zip(chains, ground_truth))
        correct = sum(
            _normalize(chain.final_answer) == _normalize(gt) for chain, gt in pairs
        )
        # Divide by the number of scored pairs: zip() truncates to the shorter
        # input, so unmatched chains must not count as incorrect.
        accuracy = correct / len(pairs) if pairs else 0.0

        results[model_name] = {
            'accuracy': accuracy,
            **evaluator.evaluate_step_quality(chains),
            **evaluator.evaluate_chain_coherence(chains),
        }

    with open(output_path / 'model_comparison.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, default=_json_default)

    model_names = list(results)
    accuracies = [results[m]['accuracy'] for m in model_names]
    # evaluate_step_quality returns {} when a model's chains contain no steps;
    # plot those as 0 instead of raising KeyError.
    avg_rewards = [results[m].get('avg_step_reward', 0.0) for m in model_names]
    avg_steps = [results[m].get('avg_steps_per_chain', 0.0) for m in model_names]

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    panels = (
        (accuracies, 'Accuracy', 'Model Accuracy Comparison'),
        (avg_rewards, 'Average Step Reward', 'Average Step Reward Comparison'),
        (avg_steps, 'Average Steps per Chain', 'Average Chain Length Comparison'),
    )
    for ax, (values, ylabel, title) in zip(axes, panels):
        ax.bar(model_names, values)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.tick_params(axis='x', rotation=45)

    fig.tight_layout()
    fig.savefig(output_path / 'model_comparison.png', dpi=300, bbox_inches='tight')
    # Close the figure so repeated comparisons do not accumulate open figures.
    plt.close(fig)

    logger.info("Model comparison saved to %s", output_path)

    return results
|
|