# dei-model/utils/reasoning_eval.py
# Source attribution: renpas22, commit da76488 ("Add utils directory")
"""
Evaluation and ablation utilities for step-level reasoning.
Provides tools for measuring:
- Step quality and PRM accuracy
- Reasoning chain coherence
- Inference-time scaling effectiveness
- Ablation studies on different components
"""
from typing import Dict, List, Optional, Tuple, Any
import torch
import numpy as np
from pathlib import Path
import json
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import logging
from .step_data import ReasoningChain, ReasoningStep, StepType
from .prm import ProcessRewardModel, StepQualityMetrics
from .inference_scaling import InferenceTimeScaling
logger = logging.getLogger(__name__)
class ReasoningEvaluator:
    """Comprehensive evaluator for step-level reasoning systems.

    Wraps an optional Process Reward Model (PRM) and provides metric
    computation, chain-coherence analysis, inference-scaling evaluation,
    ablation-study scaffolding, and visualization for reasoning chains.
    """

    def __init__(
        self,
        prm: Optional[ProcessRewardModel] = None,
        device: str = "cuda",
    ):
        """
        Initialize evaluator.

        Args:
            prm: Process Reward Model for evaluation (optional).
            device: Device for computation; only used when ``prm`` is given.
        """
        self.prm = prm
        self.device = device
        if self.prm:
            # Move the PRM to the target device and freeze it in eval mode so
            # evaluation never triggers dropout / batch-norm statistic updates.
            self.prm.to(device)
            self.prm.eval()

    def evaluate_step_quality(
        self,
        chains: List[ReasoningChain],
        ground_truth_rewards: Optional[List[List[float]]] = None,
    ) -> Dict[str, Any]:
        """
        Evaluate step-level quality metrics.

        Args:
            chains: List of reasoning chains.
            ground_truth_rewards: Optional per-chain lists of ground-truth
                step rewards, aligned positionally with each chain's steps.

        Returns:
            Dictionary of evaluation metrics; empty when there are no steps.
            Scalar values are plain Python numbers so the result is
            JSON-serializable (the report generator dumps this dict).
        """
        all_steps = [step for chain in chains for step in chain.steps]
        if not all_steps:
            return {}

        # Basic statistics over every step in every chain.
        step_rewards = [step.reward for step in all_steps]
        step_confidences = [step.confidence for step in all_steps]
        # float(...) casts keep numpy scalar types out of the JSON report.
        metrics: Dict[str, Any] = {
            'num_chains': len(chains),
            'total_steps': len(all_steps),
            'avg_steps_per_chain': len(all_steps) / len(chains),
            'avg_step_reward': float(np.mean(step_rewards)),
            'std_step_reward': float(np.std(step_rewards)),
            'min_step_reward': float(np.min(step_rewards)),
            'max_step_reward': float(np.max(step_rewards)),
            'avg_step_confidence': float(np.mean(step_confidences)),
        }

        # Step type distribution (count per StepType value).
        step_type_counts = defaultdict(int)
        for step in all_steps:
            step_type_counts[step.step_type.value] += 1
        metrics['step_type_distribution'] = dict(step_type_counts)

        # PRM accuracy (if ground truth available).
        # NOTE(review): the stored step.reward values are compared against
        # ground truth; self.prm is only used as a gate and is never invoked
        # here — confirm that is the intended behavior.
        if ground_truth_rewards and self.prm:
            predicted_rewards = []
            true_rewards = []
            for chain, gt_rewards in zip(chains, ground_truth_rewards):
                for step, gt_reward in zip(chain.steps, gt_rewards):
                    predicted_rewards.append(step.reward)
                    true_rewards.append(gt_reward)
            if predicted_rewards:
                pred = np.asarray(predicted_rewards, dtype=np.float64)
                true = np.asarray(true_rewards, dtype=np.float64)
                metrics['prm_mse'] = float(np.mean((pred - true) ** 2))
                metrics['prm_mae'] = float(np.mean(np.abs(pred - true)))
                # Pearson correlation is undefined (NaN) for constant or
                # single-element series; report 0.0 instead of letting NaN
                # propagate into the JSON report.
                if len(pred) > 1 and pred.std() > 0 and true.std() > 0:
                    metrics['prm_correlation'] = float(
                        np.corrcoef(pred, true)[0, 1]
                    )
                else:
                    metrics['prm_correlation'] = 0.0
        return metrics

    def evaluate_chain_coherence(
        self,
        chains: List[ReasoningChain],
    ) -> Dict[str, float]:
        """
        Evaluate reasoning chain coherence and structure.

        Args:
            chains: List of reasoning chains.

        Returns:
            Coherence metrics as plain Python numbers (JSON-serializable).
        """
        coherence_scores = []
        dependency_depths = []
        for chain in chains:
            # Reward consistency: smaller step-to-step reward jumps give a
            # higher coherence score (score lies in (0, 1]).
            if len(chain.steps) > 1:
                reward_diffs = [
                    abs(chain.steps[i].reward - chain.steps[i - 1].reward)
                    for i in range(1, len(chain.steps))
                ]
                coherence = 1.0 / (1.0 + float(np.std(reward_diffs)))
                coherence_scores.append(coherence)
            # NOTE: "depth" here is the maximum number of *direct*
            # dependencies of any single step, not the transitive depth of
            # the dependency DAG.
            max_depth = 0
            for step in chain.steps:
                if step.dependencies:
                    max_depth = max(max_depth, len(step.dependencies))
            dependency_depths.append(max_depth)
        return {
            'avg_coherence': float(np.mean(coherence_scores)) if coherence_scores else 0.0,
            'std_coherence': float(np.std(coherence_scores)) if coherence_scores else 0.0,
            'avg_dependency_depth': float(np.mean(dependency_depths)) if dependency_depths else 0.0,
            'max_dependency_depth': max(dependency_depths) if dependency_depths else 0,
        }

    def evaluate_inference_scaling(
        self,
        chains_by_sample_count: Dict[int, List[ReasoningChain]],
        ground_truth: List[str],
    ) -> Dict[str, Any]:
        """
        Evaluate effectiveness of inference-time scaling.

        Args:
            chains_by_sample_count: Mapping from num_samples to generated chains.
            ground_truth: Ground truth answers, compared case- and
                whitespace-insensitively against each chain's final answer.

        Returns:
            Per-sample-count metrics, plus a 'scaling_benefit' entry when
            more than one sample count was evaluated.
        """
        results: Dict[Any, Any] = {}
        for num_samples, chains in sorted(chains_by_sample_count.items()):
            if not chains:
                # Skip empty buckets instead of dividing by zero below.
                logger.warning("No chains for num_samples=%d; skipping", num_samples)
                continue
            # Exact-match accuracy; a missing final answer counts as
            # incorrect rather than raising AttributeError.
            correct = sum(
                1 for chain, gt in zip(chains, ground_truth)
                if (chain.final_answer or "").strip().lower() == gt.strip().lower()
            )
            results[num_samples] = {
                'accuracy': correct / len(chains),
                'avg_reward': float(np.mean([c.total_reward for c in chains])),
                'avg_steps': float(np.mean([len(c) for c in chains])),
            }
        # Scaling benefit: accuracy gain from the smallest to the largest
        # sample budget actually evaluated.
        sample_counts = sorted(k for k in results if isinstance(k, int))
        if len(sample_counts) > 1:
            baseline_acc = results[sample_counts[0]]['accuracy']
            best_acc = results[sample_counts[-1]]['accuracy']
            results['scaling_benefit'] = {
                'accuracy_improvement': best_acc - baseline_acc,
                'relative_improvement': (
                    (best_acc - baseline_acc) / baseline_acc if baseline_acc > 0 else 0
                ),
            }
        return results

    def ablation_study(
        self,
        model: Any,
        test_chains: List[ReasoningChain],
        components: List[str] = ["prm", "rl", "inference_scaling"],
    ) -> Dict[str, Dict[str, float]]:
        """
        Perform ablation study on different components.

        Args:
            model: Vision-language model under evaluation.
            test_chains: Test reasoning chains.
            components: Components to ablate; recognized values are
                "baseline", "prm", "rl", and "inference_scaling".

        Returns:
            Ablation results for each requested component, always including
            a 'full_system' entry.
        """
        # NOTE(review): mutable default argument — safe only because this
        # method never mutates `components`; do not append to it.
        results = {}
        # Baseline: no reasoning.
        if "baseline" in components:
            logger.info("Evaluating baseline (no reasoning)")
            results['baseline'] = self._evaluate_without_reasoning(model, test_chains)
        # With PRM only (requires a PRM to be configured).
        if "prm" in components and self.prm:
            logger.info("Evaluating with PRM only")
            results['prm_only'] = self._evaluate_with_prm_only(model, test_chains)
        # With RL only (no PRM).
        if "rl" in components:
            logger.info("Evaluating with RL only")
            results['rl_only'] = self._evaluate_with_rl_only(model, test_chains)
        # With inference scaling only.
        if "inference_scaling" in components:
            logger.info("Evaluating with inference scaling only")
            results['inference_scaling_only'] = self._evaluate_with_scaling_only(
                model, test_chains
            )
        # Full system is always evaluated.
        logger.info("Evaluating full system")
        results['full_system'] = self._evaluate_full_system(model, test_chains)
        return results

    def _evaluate_without_reasoning(
        self,
        model: Any,
        test_chains: List[ReasoningChain],
    ) -> Dict[str, float]:
        """Baseline evaluation without step-level reasoning.

        Placeholder: implementation depends on the concrete model.
        """
        return {
            'accuracy': 0.0,
            'avg_reward': 0.0,
        }

    def _evaluate_with_prm_only(
        self,
        model: Any,
        test_chains: List[ReasoningChain],
    ) -> Dict[str, float]:
        """Evaluate with PRM but no RL training.

        Placeholder: uses the PRM for scoring while the model is trained
        with a standard loss.
        """
        return {
            'accuracy': 0.0,
            'avg_prm_reward': 0.0,
        }

    def _evaluate_with_rl_only(
        self,
        model: Any,
        test_chains: List[ReasoningChain],
    ) -> Dict[str, float]:
        """Evaluate with RL but no PRM (outcome rewards only). Placeholder."""
        return {
            'accuracy': 0.0,
            'avg_reward': 0.0,
        }

    def _evaluate_with_scaling_only(
        self,
        model: Any,
        test_chains: List[ReasoningChain],
    ) -> Dict[str, float]:
        """Evaluate with inference scaling but no RL or PRM training. Placeholder."""
        return {
            'accuracy': 0.0,
            'avg_reward': 0.0,
        }

    def _evaluate_full_system(
        self,
        model: Any,
        test_chains: List[ReasoningChain],
    ) -> Dict[str, float]:
        """Evaluate the complete system with all components. Placeholder."""
        return {
            'accuracy': 0.0,
            'avg_reward': 0.0,
        }

    def visualize_reasoning_chains(
        self,
        chains: List[ReasoningChain],
        save_path: Optional[str] = None,
    ) -> None:
        """
        Visualize reasoning chains and step rewards in a 2x2 figure:
        reward histogram, length-vs-reward scatter, step-type bar chart,
        and cumulative-reward progression for the first 10 chains.

        Args:
            chains: Reasoning chains to visualize.
            save_path: Optional path to save the figure; when omitted the
                figure is shown interactively.
        """
        all_rewards = [step.reward for chain in chains for step in chain.steps]
        if not all_rewards:
            # Nothing to plot; avoid NaN means and empty-axis warnings.
            logger.warning("No steps to visualize; skipping figure generation")
            return
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        # 1. Step reward distribution.
        axes[0, 0].hist(all_rewards, bins=50, edgecolor='black')
        axes[0, 0].set_xlabel('Step Reward')
        axes[0, 0].set_ylabel('Frequency')
        axes[0, 0].set_title('Distribution of Step Rewards')
        axes[0, 0].axvline(np.mean(all_rewards), color='red', linestyle='--', label='Mean')
        axes[0, 0].legend()
        # 2. Chain length vs total reward.
        chain_lengths = [len(chain) for chain in chains]
        chain_rewards = [chain.total_reward for chain in chains]
        axes[0, 1].scatter(chain_lengths, chain_rewards, alpha=0.6)
        axes[0, 1].set_xlabel('Number of Steps')
        axes[0, 1].set_ylabel('Total Reward')
        axes[0, 1].set_title('Chain Length vs Total Reward')
        # 3. Step type distribution.
        step_types = defaultdict(int)
        for chain in chains:
            for step in chain.steps:
                step_types[step.step_type.value] += 1
        axes[1, 0].bar(step_types.keys(), step_types.values())
        axes[1, 0].set_xlabel('Step Type')
        axes[1, 0].set_ylabel('Count')
        axes[1, 0].set_title('Step Type Distribution')
        axes[1, 0].tick_params(axis='x', rotation=45)
        # 4. Cumulative reward progression (first 10 chains only, to keep
        # the legend readable).
        for i, chain in enumerate(chains[:10]):
            cumulative_rewards = chain.get_cumulative_rewards()
            axes[1, 1].plot(range(1, len(cumulative_rewards) + 1), cumulative_rewards,
                            alpha=0.6, label=f'Chain {i+1}')
        axes[1, 1].set_xlabel('Step Number')
        axes[1, 1].set_ylabel('Cumulative Reward')
        axes[1, 1].set_title('Cumulative Reward Progression (First 10 Chains)')
        axes[1, 1].legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info(f"Visualization saved to {save_path}")
            # Close the figure so repeated report generation does not leak
            # open matplotlib figures.
            plt.close(fig)
        else:
            plt.show()

    def generate_evaluation_report(
        self,
        chains: List[ReasoningChain],
        output_path: str,
        include_ablation: bool = False,
    ) -> None:
        """
        Generate a comprehensive evaluation report (JSON + PNG).

        Args:
            chains: Reasoning chains to evaluate.
            output_path: Base path for outputs; suffix is replaced with
                '.json' for the report and '.png' for the visualization.
            include_ablation: Reserved flag; ablation is currently run
                separately via :meth:`ablation_study`.
        """
        report = {
            'summary': {
                'num_chains': len(chains),
                'total_steps': sum(len(c) for c in chains),
            },
            'step_quality': self.evaluate_step_quality(chains),
            'chain_coherence': self.evaluate_chain_coherence(chains),
        }
        # Save JSON report; default=float covers any lingering numpy scalars.
        json_path = Path(output_path).with_suffix('.json')
        with open(json_path, 'w') as f:
            json.dump(report, f, indent=2, default=float)
        logger.info(f"Evaluation report saved to {json_path}")
        # Save visualizations alongside the JSON report.
        viz_path = Path(output_path).with_suffix('.png')
        self.visualize_reasoning_chains(chains, str(viz_path))
        logger.info(f"Visualizations saved to {viz_path}")
def compare_models(
    model_chains: Dict[str, List[ReasoningChain]],
    ground_truth: List[str],
    output_dir: str,
) -> Dict[str, Dict[str, float]]:
    """
    Compare multiple models on reasoning tasks.

    Writes 'model_comparison.json' and 'model_comparison.png' into
    ``output_dir`` (created if missing).

    Args:
        model_chains: Mapping from model name to reasoning chains.
        ground_truth: Ground truth answers, aligned positionally with each
            model's chains.
        output_dir: Directory to save comparison results.

    Returns:
        Comparison metrics for each model: accuracy plus step-quality and
        coherence metrics.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    evaluator = ReasoningEvaluator()
    results: Dict[str, Dict[str, float]] = {}
    for model_name, chains in model_chains.items():
        logger.info(f"Evaluating {model_name}")
        # Exact-match accuracy, case/whitespace-insensitive; a missing final
        # answer counts as incorrect instead of raising AttributeError.
        correct = sum(
            1 for chain, gt in zip(chains, ground_truth)
            if (chain.final_answer or "").strip().lower() == gt.strip().lower()
        )
        # Guard against an empty chain list (ZeroDivisionError otherwise).
        accuracy = correct / len(chains) if chains else 0.0
        # Step-quality and coherence metrics from the shared evaluator.
        quality_metrics = evaluator.evaluate_step_quality(chains)
        coherence_metrics = evaluator.evaluate_chain_coherence(chains)
        results[model_name] = {
            'accuracy': accuracy,
            **quality_metrics,
            **coherence_metrics,
        }
    # Save comparison as JSON; default=float covers any numpy scalar values
    # inside the metric dicts.
    with open(output_path / 'model_comparison.json', 'w') as f:
        json.dump(results, f, indent=2, default=float)
    # Create comparison plots. Use .get() with defaults because the quality
    # metrics dict is empty when a model produced no steps at all.
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    model_names = list(results.keys())
    accuracies = [results[m]['accuracy'] for m in model_names]
    avg_rewards = [results[m].get('avg_step_reward', 0.0) for m in model_names]
    avg_steps = [results[m].get('avg_steps_per_chain', 0.0) for m in model_names]
    panels = (
        (accuracies, 'Accuracy', 'Model Accuracy Comparison'),
        (avg_rewards, 'Average Step Reward', 'Average Step Reward Comparison'),
        (avg_steps, 'Average Steps per Chain', 'Average Chain Length Comparison'),
    )
    for ax, (values, ylabel, title) in zip(axes, panels):
        ax.bar(model_names, values)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.tick_params(axis='x', rotation=45)
    plt.tight_layout()
    plt.savefig(output_path / 'model_comparison.png', dpi=300, bbox_inches='tight')
    # Close the figure so repeated comparisons don't leak open figures.
    plt.close(fig)
    logger.info(f"Model comparison saved to {output_path}")
    return results