"""
Data structures for step-level reasoning chains with vision-language annotations.
"""

import json
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional

import torch
from torch.utils.data import Dataset

# Special tokens for chain-of-thought reasoning, declared as (token, description)
# pairs so each token sits next to its human-readable description. The two
# public lists below are derived from this table, preserving its order.
_SPECIAL_TOKEN_SPECS = (
    # Reasoning structure
    ("<|reasoning_start|>", "start of reasoning"),
    ("<|reasoning_end|>", "end of reasoning"),
    # Step-level markers
    ("<|step_start|>", "start of step"),
    ("<|step_end|>", "end of step"),
    ("<|step_type|>", "step type indicator"),
    # Step components
    ("<|description_start|>", "start of description"),
    ("<|description_end|>", "end of description"),
    ("<|confidence_start|>", "start of confidence"),
    ("<|confidence_end|>", "end of confidence"),
    # Dependencies
    ("<|depends_on|>", "depends on previous step"),
    # PRM evaluation tokens (from PRM800K)
    ("ки", "step separator for PRM"),  # step separator for PRM scoring
    ("к+и", "positive step"),  # positive step label
    ("к-и", "negative step"),  # negative step label
    # Final answer
    ("<|answer_start|>", "start of answer"),
    ("<|answer_end|>", "end of answer"),
    # Region tokens (inspired by Chain-of-Sight for visual grounding):
    # region markers and bounding box coordinates
    ("<|region|>", "region reference"),
    ("<|bbox|>", "bounding box coordinates"),
)

# Kept as mutable lists on purpose: module code further down extends both
# with numbered region tokens.
SPECIAL_TOKENS = [token for token, _description in _SPECIAL_TOKEN_SPECS]
SPECIAL_TOKEN_DESCRIPTIONS = [description for _token, description in _SPECIAL_TOKEN_SPECS]

# Generate additional numbered region tokens
# for referring to specific visual regions (Chain-of-Sight style).
#
# NOTE(review): in this copy of the file both the generated token list and the
# region references in ``ReasoningStep.format_with_tokens`` were written as
# empty f-strings (``f""``). That registered ten empty strings as special
# tokens and rendered region references as blank text. The ``<region_{i}>``
# spelling below is a reconstruction -- confirm it matches any tokenizer or
# checkpoint trained with these tokens.
NUM_REGION_TOKENS = 10


def _region_token(region_id: int) -> str:
    """Return the token string that refers to visual region ``region_id``.

    Single source of truth for the spelling, shared by ``REGION_TOKENS`` and
    ``ReasoningStep.format_with_tokens`` so the two cannot drift apart.
    """
    return f"<region_{region_id}>"


REGION_TOKENS = [_region_token(i) for i in range(1, NUM_REGION_TOKENS + 1)]
SPECIAL_TOKENS.extend(REGION_TOKENS)
SPECIAL_TOKEN_DESCRIPTIONS.extend(
    [f"visual region {i}" for i in range(1, NUM_REGION_TOKENS + 1)]
)


class StepType(Enum):
    """Types of reasoning steps in vision-language tasks."""

    PERCEPTION = "perception"  # Visual understanding
    LOCALIZATION = "localization"  # Spatial reasoning
    COMPARISON = "comparison"  # Comparing visual elements
    COUNTING = "counting"  # Numerical reasoning
    INFERENCE = "inference"  # Logical deduction
    COMPOSITION = "composition"  # Combining information
    VERIFICATION = "verification"  # Checking consistency


@dataclass
class ReasoningStep:
    """
    A single step in the chain of thought reasoning process.

    Attributes:
        step_id: Unique identifier for this step
        step_type: Type of reasoning performed
        description: Natural language description of the reasoning
        visual_features: Extracted visual features relevant to this step
        attention_map: Spatial attention weights (H, W)
        confidence: Model's confidence in this step [0, 1]
        reward: Process reward score assigned by PRM [-1, 1]
        dependencies: IDs of previous steps this depends on
        bounding_boxes: List of [x1, y1, x2, y2] bounding boxes (Chain-of-Sight style)
        region_ids: List of region IDs referenced in this step (e.g., [1, 2])
        metadata: Additional step-specific information
    """

    step_id: int
    step_type: StepType
    description: str
    visual_features: Optional[torch.Tensor] = None
    attention_map: Optional[torch.Tensor] = None
    confidence: float = 0.0
    reward: float = 0.0
    dependencies: List[int] = field(default_factory=list)
    bounding_boxes: List[List[float]] = field(default_factory=list)  # [[x1, y1, x2, y2], ...]
    region_ids: List[int] = field(default_factory=list)  # e.g. [1, 2, 3]
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary for serialization.

        Tensors (``visual_features``, ``attention_map``) are not serialized;
        only ``has_*`` flags record whether they were present. Bounding boxes
        (coerced to plain floats) and region IDs are included so that
        ``from_dict`` can restore them after a JSON round-trip.
        """
        return {
            'step_id': self.step_id,
            'step_type': self.step_type.value,
            'description': self.description,
            'confidence': self.confidence,
            'reward': self.reward,
            'dependencies': list(self.dependencies),
            # float() accepts Python numbers as well as 0-d tensors/arrays,
            # keeping the output JSON-serializable.
            'bounding_boxes': [[float(c) for c in box] for box in self.bounding_boxes],
            'region_ids': list(self.region_ids),
            'metadata': self.metadata,
            'has_visual_features': self.visual_features is not None,
            'has_attention_map': self.attention_map is not None,
        }

    def format_with_tokens(self) -> str:
        """Render this step with special tokens for training (Chain-of-Sight region support).

        Parts are joined by single spaces in this order:
        ``<|step_start|>``, ``<|step_type|>TYPE``, optional
        ``<|depends_on|>IDS``, optional ``<|region|>REFS``, optional
        ``<|bbox|>BOXES``, the description span, the confidence span (3
        decimals), the PRM separator ``ки``, and ``<|step_end|>``. Optional
        parts are emitted only when the corresponding list is non-empty.
        """
        parts = ["<|step_start|>", f"<|step_type|>{self.step_type.value}"]

        if self.dependencies:
            parts.append("<|depends_on|>" + ",".join(map(str, self.dependencies)))

        # Region references (Chain-of-Sight style).
        # NOTE(review): ids outside 1..NUM_REGION_TOKENS yield tokens that are
        # not registered in SPECIAL_TOKENS.
        if self.region_ids:
            region_refs = " ".join(_region_token(rid) for rid in self.region_ids)
            parts.append(f"<|region|>{region_refs}")

        # Each box is assumed to have at least 4 coordinates.
        if self.bounding_boxes:
            bbox_strs = [
                f"[{b[0]:.3f},{b[1]:.3f},{b[2]:.3f},{b[3]:.3f}]"
                for b in self.bounding_boxes
            ]
            parts.append(f"<|bbox|>{','.join(bbox_strs)}")

        parts.append(f"<|description_start|>{self.description}<|description_end|>")
        parts.append(f"<|confidence_start|>{self.confidence:.3f}<|confidence_end|>")
        # PRM step separator (the position a process reward model scores).
        parts.append("ки")
        parts.append("<|step_end|>")
        return " ".join(parts)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ReasoningStep':
        """Create a step from a dictionary produced by ``to_dict``.

        ``step_id``, ``step_type`` and ``description`` are required; other
        fields fall back to their defaults when missing or null. Containers are
        copied so the new step does not share mutable state with ``data``.
        Tensor fields are not restored.
        """
        return cls(
            step_id=data['step_id'],
            step_type=StepType(data['step_type']),
            description=data['description'],
            confidence=data.get('confidence', 0.0),
            reward=data.get('reward', 0.0),
            dependencies=list(data.get('dependencies') or []),
            bounding_boxes=[list(box) for box in (data.get('bounding_boxes') or [])],
            region_ids=list(data.get('region_ids') or []),
            metadata=dict(data.get('metadata') or {}),
        )


@dataclass
class ReasoningChain:
    """
    Complete chain of reasoning steps for a vision-language task.

    Attributes:
        chain_id: Unique identifier for this reasoning chain
        image_path: Path to the input image (None when ``image`` is given directly)
        image: PIL Image object (for streaming datasets); not serialized
        prompt: Input text prompt/question
        steps: Ordered list of reasoning steps
        final_answer: Final output/answer
        total_reward: Cumulative reward from PRM
        is_correct: Whether final answer is correct
        metadata: Additional chain-level information
    """

    chain_id: str
    image_path: Optional[str] = None
    image: Any = None
    prompt: str = ""
    steps: List[ReasoningStep] = field(default_factory=list)
    final_answer: str = ""
    total_reward: float = 0.0
    is_correct: bool = False
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __len__(self) -> int:
        """Number of reasoning steps.

        NOTE: because ``__len__`` is defined (and ``__bool__`` is not), a chain
        with no steps is falsy; use ``chain is None`` rather than ``not chain``
        to test for a missing chain.
        """
        return len(self.steps)

    def get_step_rewards(self) -> List[float]:
        """Extract the reward of each step, in order."""
        return [step.reward for step in self.steps]

    def get_cumulative_rewards(self) -> List[float]:
        """Return the running sum of step rewards up to and including each step."""
        cumulative = []
        total = 0.0
        for step in self.steps:
            total += step.reward
            cumulative.append(total)
        return cumulative

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary for serialization (the ``image`` object is omitted)."""
        return {
            'chain_id': self.chain_id,
            'image_path': self.image_path,
            'prompt': self.prompt,
            'steps': [step.to_dict() for step in self.steps],
            'final_answer': self.final_answer,
            'total_reward': self.total_reward,
            'is_correct': self.is_correct,
            'metadata': self.metadata,
        }

    def format_with_tokens(self) -> str:
        """Format the entire chain (all steps, then the answer) with special tokens."""
        parts = ["<|reasoning_start|>"]
        parts.extend(step.format_with_tokens() for step in self.steps)
        parts.append("<|reasoning_end|>")
        parts.append(f"<|answer_start|>{self.final_answer}<|answer_end|>")
        return " ".join(parts)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ReasoningChain':
        """Create a chain from a dictionary produced by ``to_dict``.

        Only ``chain_id`` is required; the remaining fields fall back to the
        dataclass defaults (e.g. ``image_path=None`` for streaming chains).
        """
        return cls(
            chain_id=data['chain_id'],
            image_path=data.get('image_path'),
            prompt=data.get('prompt', ""),
            steps=[ReasoningStep.from_dict(s) for s in data.get('steps', [])],
            final_answer=data.get('final_answer', ""),
            total_reward=data.get('total_reward', 0.0),
            is_correct=data.get('is_correct', False),
            metadata=data.get('metadata', {}),
        )

    def save_json(self, path: str) -> None:
        """Save the reasoning chain to a UTF-8 JSON file (non-ASCII kept verbatim)."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

    @classmethod
    def load_json(cls, path: str) -> 'ReasoningChain':
        """Load a reasoning chain from a UTF-8 JSON file."""
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return cls.from_dict(data)


class StepDataset(Dataset):
    """
    Dataset for training with step-level reasoning chains.

    Supports both supervised fine-tuning and RL training with process rewards.
    """

    def __init__(
        self,
        reasoning_chains: List[ReasoningChain],
        tokenizer: Any,
        max_steps: int = 10,
        include_visual_features: bool = True,
        filter_incorrect: bool = False,
        max_length: int = 2048,
        max_prompt_length: int = 512,
        max_answer_length: int = 256,
        visual_feature_dim: int = 768,
    ):
        """
        Initialize step dataset.

        Args:
            reasoning_chains: List of annotated reasoning chains (copied)
            tokenizer: Callable tokenizer invoked as
                ``tokenizer(text, truncation=True, max_length=..., return_tensors='pt')``
                and expected to return ``input_ids`` / ``attention_mask``
            max_steps: Maximum number of steps per chain in the per-step fields
            include_visual_features: Whether to include visual features
            filter_incorrect: If True, only include correct chains
            max_length: Token limit for the fully formatted chain
            max_prompt_length: Token limit for the prompt
            max_answer_length: Token limit for the final answer
            visual_feature_dim: Size of the zero placeholder used for steps
                without visual features when no step in the chain has any
        """
        self.tokenizer = tokenizer
        self.max_steps = max_steps
        self.include_visual_features = include_visual_features
        self.max_length = max_length
        self.max_prompt_length = max_prompt_length
        self.max_answer_length = max_answer_length
        self.visual_feature_dim = visual_feature_dim

        # Always build a fresh list so later changes to the caller's list do
        # not silently alter the dataset or invalidate the cached statistics.
        if filter_incorrect:
            self.chains = [c for c in reasoning_chains if c.is_correct]
        else:
            self.chains = list(reasoning_chains)

        # Pre-compute statistics
        self.total_steps = sum(len(c) for c in self.chains)
        self.avg_steps = self.total_steps / len(self.chains) if self.chains else 0

    def __len__(self) -> int:
        return len(self.chains)

    def _tokenize(self, text: str, max_length: int) -> Any:
        """Run the tokenizer on ``text`` with truncation and ``return_tensors='pt'``."""
        return self.tokenizer(
            text,
            truncation=True,
            max_length=max_length,
            return_tensors='pt',
        )

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Get a single reasoning chain with all steps.

        Returns:
            Dictionary containing:
                - prompt_ids / prompt_mask: Tokenized prompt
                - formatted_input_ids / formatted_attention_mask: Tokenized
                  chain with special tokens (all steps, not only ``max_steps``)
                - formatted_text: Raw text with special tokens
                - step_descriptions: List of step descriptions
                - step_types: List of step type values
                - step_rewards: float32 tensor of process rewards
                - answer_ids / answer_mask: Tokenized final answer
                - is_correct: bool tensor, ground truth correctness
                - total_reward: float32 tensor, chain-level reward
                - num_steps: Number of steps in the chain (before truncation)
                - chain_id, image_path: Identifiers copied from the chain
                - visual_features: Stacked per-step features (only if enabled
                  and at least one step is present)
                - attention_maps: List of attention maps from steps that have
                  one (only if any exist)

            The per-step fields cover at most the first ``max_steps`` steps.
        """
        chain = self.chains[idx]

        # Full chain with special tokens (for training).
        formatted_text = chain.format_with_tokens()
        formatted_encoding = self._tokenize(formatted_text, self.max_length)
        # Prompt alone (for compatibility).
        prompt_encoding = self._tokenize(chain.prompt, self.max_prompt_length)

        steps = chain.steps[:self.max_steps]
        step_descriptions = [step.description for step in steps]
        step_types = [step.step_type.value for step in steps]
        step_rewards = [step.reward for step in steps]

        visual_features = []
        if self.include_visual_features:
            # Placeholder for steps without features: zeros matching the shape,
            # dtype and device of the chain's first real feature so that
            # torch.stack can combine them; fall back to a flat vector of
            # ``visual_feature_dim`` when no step has features.
            reference = next(
                (s.visual_features for s in steps if s.visual_features is not None),
                None,
            )
            for step in steps:
                if step.visual_features is not None:
                    visual_features.append(step.visual_features)
                elif reference is not None:
                    visual_features.append(torch.zeros_like(reference))
                else:
                    visual_features.append(torch.zeros(self.visual_feature_dim))

        # Only steps that carry an attention map contribute, so this list may be
        # shorter than (and not index-aligned with) the other per-step lists.
        attention_maps = [s.attention_map for s in steps if s.attention_map is not None]

        answer_encoding = self._tokenize(chain.final_answer, self.max_answer_length)

        result = {
            'prompt_ids': prompt_encoding['input_ids'].squeeze(0),
            'prompt_mask': prompt_encoding['attention_mask'].squeeze(0),
            'formatted_input_ids': formatted_encoding['input_ids'].squeeze(0),
            'formatted_attention_mask': formatted_encoding['attention_mask'].squeeze(0),
            'formatted_text': formatted_text,
            'step_descriptions': step_descriptions,
            'step_types': step_types,
            'step_rewards': torch.tensor(step_rewards, dtype=torch.float32),
            'answer_ids': answer_encoding['input_ids'].squeeze(0),
            'answer_mask': answer_encoding['attention_mask'].squeeze(0),
            'is_correct': torch.tensor(chain.is_correct, dtype=torch.bool),
            'total_reward': torch.tensor(chain.total_reward, dtype=torch.float32),
            # NOTE: counts every step, not only the ``max_steps`` kept above.
            'num_steps': len(chain.steps),
            'chain_id': chain.chain_id,
            'image_path': chain.image_path,
        }

        if self.include_visual_features and visual_features:
            result['visual_features'] = torch.stack(visual_features)

        if attention_maps:
            result['attention_maps'] = attention_maps

        return result

    @classmethod
    def from_json_dir(
        cls,
        json_dir: str,
        tokenizer: Any,
        **kwargs
    ) -> 'StepDataset':
        """
        Load dataset from directory of JSON files.

        Files matching ``*.json`` directly inside ``json_dir`` are loaded in
        sorted filename order, so dataset indices are reproducible across
        filesystems. Files that fail to load are reported and skipped.

        Args:
            json_dir: Directory containing reasoning chain JSON files
            tokenizer: Tokenizer for text encoding
            **kwargs: Additional arguments for StepDataset
        """
        from pathlib import Path

        json_path = Path(json_dir)
        chains = []

        # glob() order is filesystem-dependent; sort for determinism.
        for json_file in sorted(json_path.glob('*.json')):
            try:
                chain = ReasoningChain.load_json(str(json_file))
                chains.append(chain)
            except Exception as e:
                # Deliberately best-effort: one bad file must not abort loading.
                print(f"Error loading {json_file}: {e}")

        return cls(chains, tokenizer, **kwargs)

    def get_statistics(self) -> Dict[str, Any]:
        """Calculate dataset statistics (counts, accuracy, step types, reward range)."""
        step_types = {}
        rewards = []
        correct_count = 0

        for chain in self.chains:
            if chain.is_correct:
                correct_count += 1

            for step in chain.steps:
                step_type = step.step_type.value
                step_types[step_type] = step_types.get(step_type, 0) + 1
                rewards.append(step.reward)

        return {
            'num_chains': len(self.chains),
            'total_steps': self.total_steps,
            'avg_steps_per_chain': self.avg_steps,
            'accuracy': correct_count / len(self.chains) if self.chains else 0,
            'step_type_distribution': step_types,
            'avg_step_reward': sum(rewards) / len(rewards) if rewards else 0,
            'min_step_reward': min(rewards) if rewards else 0,
            'max_step_reward': max(rewards) if rewards else 0,
        }