| | """ |
| | Visualization utilities for understanding step-level reasoning. |
| | |
| | Creates diagrams and plots to explain the system architecture and results. |
| | """ |
| |
|
| | import matplotlib.pyplot as plt |
| | import matplotlib.patches as mpatches |
| | from matplotlib.patches import FancyBboxPatch, FancyArrowPatch |
| | import numpy as np |
| | from typing import List, Dict |
| | import seaborn as sns |
| |
|
| |
|
def _add_stage_box(ax, xy, width, height, facecolor, labels):
    """Draw one rounded stage box on ``ax`` and write its labels over it.

    Args:
        ax: Matplotlib axes to draw on.
        xy: ``(x, y)`` of the box's lower-left corner in data coordinates.
        width: Box width in data units (before the 0.1 padding).
        height: Box height in data units (before the 0.1 padding).
        facecolor: Fill colour of the box.
        labels: Iterable of ``(x, y, text, style)`` tuples; ``style`` holds
            the per-label ``ax.text`` keywords (alignment, font size, weight).
            Every label is drawn white and vertically centred.
    """
    ax.add_patch(FancyBboxPatch(xy, width, height, boxstyle="round,pad=0.1",
                                facecolor=facecolor, edgecolor='black',
                                linewidth=2))
    for x, y, text, style in labels:
        ax.text(x, y, text, va='center', color='white', **style)


def visualize_architecture(save_path: str = "step_cot_architecture.png"):
    """Create architecture diagram for the step-level CoT system.

    Draws five stacked stage boxes (model, reasoning generator, PRM, PPO,
    inference-time scaling) joined by downward arrows, saves the figure at
    300 dpi and prints the output path.

    Args:
        save_path: File path handed to ``Figure.savefig``.
    """
    colors = {
        'model': '#3498db',
        'reasoning': '#2ecc71',
        'prm': '#e74c3c',
        'rl': '#f39c12',
        'inference': '#9b59b6',
    }

    # Text styles used inside the boxes.
    heading = {'ha': 'center', 'fontsize': 12, 'fontweight': 'bold'}
    caption = {'ha': 'center', 'fontsize': 9}
    centered = {'ha': 'center', 'fontsize': 8}
    left = {'ha': 'left', 'fontsize': 8}

    # One entry per stage, top to bottom:
    # (lower-left corner, width, height, fill colour, labels).
    stages = [
        ((1, 9.5), 8, 1.2, colors['model'], [
            (5, 10.1, 'Vision-Language Model', heading),
            (5, 9.8, '(Qwen-Image / Custom VLM)', caption),
        ]),
        ((1, 7), 8, 1.5, colors['reasoning'], [
            (5, 8, 'Step-Level Reasoning Generator', heading),
            (5, 7.6, '• Iterative step generation', centered),
            (5, 7.3, '• Visual features + hidden states extraction', centered),
        ]),
        ((0.5, 4.5), 9, 1.5, colors['prm'], [
            (5, 5.5, 'Process Reward Model (PRM)', heading),
            (2.5, 5, '• Coherence', left),
            (2.5, 4.75, '• Relevance', left),
            (5, 5, '• Correctness', left),
            (5, 4.75, '• Informativeness', left),
            (7.5, 5, '• Confidence', left),
        ]),
        ((1, 2), 8, 1.5, colors['rl'], [
            (5, 3, 'Reinforcement Learning (PPO)', heading),
            (5, 2.6, '• Policy optimization with step-level rewards', centered),
            (5, 2.3, '• Value function + GAE', centered),
        ]),
        ((1, 0), 8, 1, colors['inference'], [
            (5, 0.5, 'Inference-Time Scaling', heading),
            (2.5, 0.2, 'Best-of-N', centered),
            (5, 0.2, 'Majority Vote', centered),
            (7.5, 0.2, 'Weighted Vote', centered),
        ]),
    ]

    fig, ax = plt.subplots(figsize=(12, 14))
    try:
        ax.set_xlim(0, 10)
        # The lowest box starts at y=0 and its 0.1 padding reaches below the
        # axis; patches are clipped to the axes area, so leave room for it.
        ax.set_ylim(-0.3, 12)
        ax.axis('off')

        ax.text(5, 11.5, 'Step-Level Chain of Thought Architecture',
                ha='center', fontsize=18, fontweight='bold')

        last_index = len(stages) - 1
        for index, (xy, width, height, facecolor, labels) in enumerate(stages):
            _add_stage_box(ax, xy, width, height, facecolor, labels)
            if index < last_index:
                # Connector from the bottom edge of this box toward the next.
                bottom = xy[1]
                ax.add_patch(FancyArrowPatch((5, bottom), (5, bottom - 0.8),
                                             arrowstyle='->',
                                             mutation_scale=30, linewidth=3,
                                             color='black'))

        fig.tight_layout()
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Architecture diagram saved to {save_path}")
    finally:
        # Close this specific figure even if drawing or saving failed.
        plt.close(fig)
| |
|
| |
|
def visualize_reasoning_example(save_path: str = "reasoning_example.png"):
    """Visualize an example reasoning chain.

    Renders a question box, one coloured box per reasoning step with its
    reward/confidence next to it, and a final answer box whose summary line
    is computed from the step data. The figure is saved at 300 dpi and the
    output path is printed.

    Args:
        save_path: File path handed to ``Figure.savefig``.
    """
    # Example data is defined inline; nothing is loaded from disk.
    steps = [
        {
            'type': 'PERCEPTION',
            'desc': 'I observe multiple objects in the image',
            'reward': 0.85,
            'confidence': 0.9,
            'color': '#3498db',
        },
        {
            'type': 'LOCALIZATION',
            'desc': 'I identify the locations of ball-shaped objects',
            'reward': 0.80,
            'confidence': 0.85,
            'color': '#2ecc71',
        },
        {
            'type': 'COMPARISON',
            'desc': 'I determine which balls are red colored',
            'reward': 0.88,
            'confidence': 0.87,
            'color': '#e74c3c',
        },
        {
            'type': 'COUNTING',
            'desc': 'I count the red balls: 1, 2, 3',
            'reward': 0.92,
            'confidence': 0.95,
            'color': '#f39c12',
        },
        {
            'type': 'VERIFICATION',
            'desc': 'I verify my count is correct',
            'reward': 0.90,
            'confidence': 0.93,
            'color': '#9b59b6',
        },
    ]

    # Derive the summary from the data so it cannot drift from the steps.
    total_reward = sum(step['reward'] for step in steps)
    avg_confidence = sum(step['confidence'] for step in steps) / len(steps)

    fig, ax = plt.subplots(figsize=(14, 10))
    try:
        ax.set_xlim(0, 10)
        ax.set_ylim(0, 8)
        ax.axis('off')

        ax.text(5, 7.5, 'Example: Step-Level Reasoning Chain',
                ha='center', fontsize=16, fontweight='bold')

        ax.add_patch(FancyBboxPatch((0.5, 6.5), 9, 0.6,
                                    boxstyle="round,pad=0.05",
                                    facecolor='#ecf0f1', edgecolor='black',
                                    linewidth=2))
        ax.text(5, 6.8, 'Question: "How many red balls are in the image?"',
                ha='center', va='center', fontsize=11, fontweight='bold')

        # Start at 6.0 so the padded bottom of the last step box (y=1.35)
        # stays clear of the padded top of the answer box (y=0.95); a start
        # of 5.5 makes the two overlap.
        first_row_y = 6.0
        row_pitch = 1.0
        last_index = len(steps) - 1
        for index, step in enumerate(steps):
            y_pos = first_row_y - index * row_pitch

            # Step box with its type heading and description.
            ax.add_patch(FancyBboxPatch((0.5, y_pos - 0.6), 6.5, 0.5,
                                        boxstyle="round,pad=0.05",
                                        facecolor=step['color'],
                                        edgecolor='black', linewidth=1.5,
                                        alpha=0.7))
            ax.text(0.7, y_pos - 0.15, f"Step {index + 1}: {step['type']}",
                    ha='left', va='top', fontsize=9, fontweight='bold',
                    color='white')
            ax.text(0.7, y_pos - 0.45, step['desc'],
                    ha='left', va='top', fontsize=8, color='white')

            # Metrics box to the right of the step box.
            ax.add_patch(FancyBboxPatch((7.2, y_pos - 0.6), 2.3, 0.5,
                                        boxstyle="round,pad=0.05",
                                        facecolor='white', edgecolor='black',
                                        linewidth=1))
            ax.text(7.4, y_pos - 0.25, f"Reward: {step['reward']:.2f}",
                    ha='left', va='center', fontsize=8)
            ax.text(7.4, y_pos - 0.5,
                    f"Confidence: {step['confidence']:.2f}",
                    ha='left', va='center', fontsize=8)

            # Connector to the next step (none after the final step).
            if index < last_index:
                ax.add_patch(FancyArrowPatch((3.5, y_pos - 0.65),
                                             (3.5, y_pos - 0.95),
                                             arrowstyle='->',
                                             mutation_scale=15, linewidth=2,
                                             color='gray'))

        ax.add_patch(FancyBboxPatch((0.5, 0.3), 9, 0.6,
                                    boxstyle="round,pad=0.05",
                                    facecolor='#27ae60', edgecolor='black',
                                    linewidth=2))
        ax.text(5, 0.6, 'Final Answer: "There are 3 red balls in the image"',
                ha='center', va='center', fontsize=11, fontweight='bold',
                color='white')
        ax.text(5, 0.35,
                f'Total Reward: {total_reward:.2f} | '
                f'Avg Confidence: {avg_confidence:.2f}',
                ha='center', va='center', fontsize=9, color='white')

        fig.tight_layout()
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Reasoning example saved to {save_path}")
    finally:
        # Close this specific figure even if drawing or saving failed.
        plt.close(fig)
| |
|
| |
|
def visualize_inference_scaling(save_path: str = "inference_scaling.png"):
    """Visualize inference-time scaling benefits.

    Produces a three-panel figure: accuracy versus number of samples (log2
    x-axis), a bar chart comparing aggregation methods, and overlaid reward
    histograms. All accuracy numbers are literals defined in this function
    and the histogram data is synthetic (seeded normal draws); nothing is
    loaded from experiment output. The figure is saved at 300 dpi and the
    output path is printed.

    Args:
        save_path: File path handed to ``Figure.savefig``.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    try:
        scaling_ax, methods_ax, rewards_ax = axes

        # Panel 1: accuracy as a function of the sample count.
        num_samples = [1, 2, 4, 8, 16, 32]
        accuracy = [65.3, 68.1, 71.4, 73.8, 75.2, 75.8]
        scaling_ax.plot(num_samples, accuracy, marker='o', linewidth=2,
                        markersize=10, color='#2ecc71')
        scaling_ax.set_xlabel('Number of Samples', fontsize=12,
                              fontweight='bold')
        scaling_ax.set_ylabel('Accuracy (%)', fontsize=12, fontweight='bold')
        scaling_ax.set_title('Accuracy vs Inference Samples', fontsize=14,
                             fontweight='bold')
        scaling_ax.grid(True, alpha=0.3)
        scaling_ax.set_xscale('log', base=2)

        # Call out the gain between the first and the last data point.
        scaling_ax.annotate(f'+{accuracy[-1] - accuracy[0]:.1f}%',
                            xy=(num_samples[-1], accuracy[-1]),
                            xytext=(20, 72),
                            arrowprops=dict(arrowstyle='->', color='red',
                                            lw=2),
                            fontsize=12, fontweight='bold', color='red')

        # Panel 2: aggregation method comparison.
        methods = ['Single\nSample', 'Best-of-8', 'Majority\nVote',
                   'Weighted\nVote']
        accuracies = [65.3, 73.8, 74.2, 75.1]
        colors_bar = ['#95a5a6', '#3498db', '#e74c3c', '#f39c12']

        bars = methods_ax.bar(methods, accuracies, color=colors_bar,
                              edgecolor='black', linewidth=1.5)
        methods_ax.set_ylabel('Accuracy (%)', fontsize=12, fontweight='bold')
        methods_ax.set_title('Aggregation Method Comparison', fontsize=14,
                             fontweight='bold')
        methods_ax.set_ylim([60, 80])
        methods_ax.grid(True, alpha=0.3, axis='y')

        # Print each bar's value just above it.
        for bar, acc in zip(bars, accuracies):
            height = bar.get_height()
            methods_ax.text(bar.get_x() + bar.get_width() / 2., height + 0.5,
                            f'{acc:.1f}%', ha='center', va='bottom',
                            fontweight='bold')

        # Panel 3: synthetic reward distributions. A private RandomState
        # seeded with 42 yields the same draws as np.random.seed(42) would,
        # without reseeding NumPy's global generator for the caller.
        rng = np.random.RandomState(42)
        all_rewards = rng.normal(0.7, 0.15, 1000)
        best_rewards = rng.normal(0.85, 0.08, 200)

        rewards_ax.hist(all_rewards, bins=30, alpha=0.7, label='All Samples',
                        color='#95a5a6', edgecolor='black')
        rewards_ax.hist(best_rewards, bins=30, alpha=0.8,
                        label='Best-of-N Selected', color='#2ecc71',
                        edgecolor='black')
        rewards_ax.set_xlabel('Reward Score', fontsize=12, fontweight='bold')
        rewards_ax.set_ylabel('Frequency', fontsize=12, fontweight='bold')
        rewards_ax.set_title('Reward Distribution: Selection Effect',
                             fontsize=14, fontweight='bold')
        rewards_ax.legend(fontsize=10, loc='upper left')
        rewards_ax.grid(True, alpha=0.3, axis='y')

        fig.tight_layout()
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Inference scaling visualization saved to {save_path}")
    finally:
        # Close this specific figure even if drawing or saving failed.
        plt.close(fig)
| |
|
| |
|
def _plot_twin_axis_panel(ax, x, left_series, right_series, xlabel, title):
    """Plot two curves on one panel, each against its own y-axis.

    Args:
        ax: Axes that carries the left-hand curve; a twin sharing its x-axis
            is created for the right-hand curve.
        x: Shared x values.
        left_series: ``(values, color, legend_label, axis_label)`` for the
            left axis, drawn with the ``'o-'`` format.
        right_series: Same tuple layout for the right axis, drawn with the
            ``'s-'`` format.
        xlabel: Label of the shared x-axis.
        title: Panel title.
    """
    left_values, left_color, left_label, left_axis_label = left_series
    right_values, right_color, right_label, right_axis_label = right_series

    twin = ax.twinx()
    handles = ax.plot(x, left_values, 'o-', color=left_color, linewidth=2,
                      label=left_label)
    handles += twin.plot(x, right_values, 's-', color=right_color,
                         linewidth=2, label=right_label)

    ax.set_xlabel(xlabel, fontsize=12, fontweight='bold')
    ax.set_ylabel(left_axis_label, fontsize=12, fontweight='bold',
                  color=left_color)
    twin.set_ylabel(right_axis_label, fontsize=12, fontweight='bold',
                    color=right_color)
    ax.set_title(title, fontsize=14, fontweight='bold')
    # Colour each y-axis like its curve so the reader can tell them apart.
    ax.tick_params(axis='y', labelcolor=left_color)
    twin.tick_params(axis='y', labelcolor=right_color)
    ax.grid(True, alpha=0.3)

    # A single legend listing the curves from both y-axes.
    ax.legend(handles, [handle.get_label() for handle in handles],
              loc='center right')


def visualize_training_phases(save_path: str = "training_phases.png"):
    """Visualize the three training phases.

    Produces a three-panel figure: PRM loss and correlation per epoch, PPO
    average reward and policy loss per iteration, and a bar chart of
    accuracy per added component. The curves are analytic exponentials and
    the bar values are literals defined here, not logged metrics. The figure
    is saved at 300 dpi and the output path is printed.

    Args:
        save_path: File path handed to ``Figure.savefig``.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    try:
        # Phase 1: decaying loss and saturating correlation over 10 epochs.
        epochs = np.arange(1, 11)
        prm_loss = 0.5 * np.exp(-epochs / 3) + 0.1
        prm_corr = 0.9 * (1 - np.exp(-epochs / 2))
        _plot_twin_axis_panel(
            axes[0], epochs,
            (prm_loss, '#e74c3c', 'Loss', 'MSE Loss'),
            (prm_corr, '#2ecc71', 'Correlation', 'Correlation'),
            xlabel='Epoch', title='Phase 1: PRM Training')

        # Phase 2: rising reward and decaying policy loss, sampled every
        # 50 iterations from 0 up to (but excluding) 1000.
        iterations = np.arange(0, 1000, 50)
        avg_reward = 0.5 + 0.3 * (1 - np.exp(-iterations / 200))
        policy_loss = 0.8 * np.exp(-iterations / 250) + 0.2
        _plot_twin_axis_panel(
            axes[1], iterations,
            (avg_reward, '#2ecc71', 'Avg Reward', 'Average Reward'),
            (policy_loss, '#f39c12', 'Policy Loss', 'Policy Loss'),
            xlabel='Iteration', title='Phase 2: RL Training (PPO)')

        # Phase 3: accuracy after each component is added.
        ablation_ax = axes[2]
        components = ['Baseline', '+ PRM', '+ RL', '+ Inference\nScaling']
        accuracy = [65.3, 71.2, 73.5, 75.8]
        colors = ['#95a5a6', '#e74c3c', '#f39c12', '#2ecc71']

        bars = ablation_ax.bar(components, accuracy, color=colors,
                               edgecolor='black', linewidth=2)
        ablation_ax.set_ylabel('Accuracy (%)', fontsize=12,
                               fontweight='bold')
        ablation_ax.set_title('Phase 3: Component Ablation', fontsize=14,
                              fontweight='bold')
        ablation_ax.set_ylim([60, 80])
        ablation_ax.grid(True, alpha=0.3, axis='y')

        # Print each bar's value just above it.
        for bar, acc in zip(bars, accuracy):
            height = bar.get_height()
            ablation_ax.text(bar.get_x() + bar.get_width() / 2.,
                             height + 0.5, f'{acc:.1f}%', ha='center',
                             va='bottom', fontweight='bold', fontsize=10)

        fig.tight_layout()
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Training phases visualization saved to {save_path}")
    finally:
        # Close this specific figure even if drawing or saving failed.
        plt.close(fig)
| |
|
| |
|
def main():
    """Render every figure to its default path and print a summary.

    Calls the four ``visualize_*`` functions in order with their default
    ``save_path`` values, then lists the file names they write by default.
    """
    separator = "=" * 60
    renderers = (
        visualize_architecture,
        visualize_reasoning_example,
        visualize_inference_scaling,
        visualize_training_phases,
    )
    # Default output names of the renderers above, in the same order.
    output_names = (
        "step_cot_architecture.png",
        "reasoning_example.png",
        "inference_scaling.png",
        "training_phases.png",
    )

    print("Generating Step-Level CoT Visualizations...")
    print(separator)

    for render in renderers:
        render()

    print(separator)
    print("All visualizations generated successfully!")
    print("\nGenerated files:")
    for name in output_names:
        print(f" - {name}")
| |
|
| |
|
# Generate the figures only when this file is executed as a script, so that
# importing the module does not write any files.
if __name__ == "__main__":
    main()
| |
|