| | """ |
| | Data structures for step-level reasoning chains with vision-language annotations. |
| | """ |
| |
|
| | from dataclasses import dataclass, field |
| | from typing import Dict, List, Optional, Any |
| | from enum import Enum |
| | import torch |
| | from torch.utils.data import Dataset |
| | import json |
| |
|
| | |
# Token/description pairs are declared together so the two parallel lists
# exported below cannot drift out of sync.
_TOKEN_SPEC = [
    # Reasoning-chain boundaries
    ("<|reasoning_start|>", "start of reasoning"),
    ("<|reasoning_end|>", "end of reasoning"),

    # Step boundaries and step typing
    ("<|step_start|>", "start of step"),
    ("<|step_end|>", "end of step"),
    ("<|step_type|>", "step type indicator"),

    # Step payload sections
    ("<|description_start|>", "start of description"),
    ("<|description_end|>", "end of description"),
    ("<|confidence_start|>", "start of confidence"),
    ("<|confidence_end|>", "end of confidence"),

    # Inter-step dependency marker
    ("<|depends_on|>", "depends on previous step"),

    # PRM (process reward model) step markers; the Cyrillic strings are kept
    # verbatim for compatibility with Math-Shepherd-style annotations.
    ("ки", "step separator for PRM"),
    ("к+и", "positive step"),
    ("к-и", "negative step"),

    # Final-answer boundaries
    ("<|answer_start|>", "start of answer"),
    ("<|answer_end|>", "end of answer"),

    # Visual grounding markers (Chain-of-Sight style)
    ("<|region|>", "region reference"),
    ("<|bbox|>", "bounding box coordinates"),
]

# Numbered visual-region tokens <region1> .. <region10>, appended last so the
# final ordering matches the historical layout of SPECIAL_TOKENS.
_TOKEN_SPEC += [(f"<region{i}>", f"visual region {i}") for i in range(1, 11)]

SPECIAL_TOKENS = [token for token, _ in _TOKEN_SPEC]
SPECIAL_TOKEN_DESCRIPTIONS = [description for _, description in _TOKEN_SPEC]

# Kept as a separate export: some callers reference the region tokens directly.
REGION_TOKENS = [f"<region{i}>" for i in range(1, 11)]
| |
|
| |
|
class StepType(Enum):
    """Types of reasoning steps in vision-language tasks."""
    PERCEPTION = "perception"      # identify objects/attributes in the image
    LOCALIZATION = "localization"  # locate a region spatially
    COMPARISON = "comparison"      # compare two or more entities
    COUNTING = "counting"          # enumerate instances
    INFERENCE = "inference"        # derive new facts from earlier steps
    COMPOSITION = "composition"    # combine intermediate results
    VERIFICATION = "verification"  # check a previous conclusion


@dataclass
class ReasoningStep:
    """
    A single step in the chain of thought reasoning process.

    Attributes:
        step_id: Unique identifier for this step
        step_type: Type of reasoning performed
        description: Natural language description of the reasoning
        visual_features: Extracted visual features relevant to this step
        attention_map: Spatial attention weights (H, W)
        confidence: Model's confidence in this step [0, 1]
        reward: Process reward score assigned by PRM [-1, 1]
        dependencies: IDs of previous steps this depends on
        bounding_boxes: List of [x1, y1, x2, y2] bounding boxes (Chain-of-Sight style)
        region_ids: List of region IDs referenced in this step (e.g., [1, 2])
        metadata: Additional step-specific information
    """
    step_id: int
    step_type: StepType
    description: str
    visual_features: Optional[torch.Tensor] = None
    attention_map: Optional[torch.Tensor] = None
    confidence: float = 0.0
    reward: float = 0.0
    dependencies: List[int] = field(default_factory=list)
    bounding_boxes: List[List[float]] = field(default_factory=list)
    region_ids: List[int] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization.

        Tensors (visual_features/attention_map) are not JSON-serializable and
        are reduced to presence flags; everything else round-trips through
        `from_dict`.
        """
        return {
            'step_id': self.step_id,
            'step_type': self.step_type.value,
            'description': self.description,
            'confidence': self.confidence,
            'reward': self.reward,
            'dependencies': self.dependencies,
            # Fix: bounding_boxes/region_ids were previously dropped here,
            # so serialization silently lost the Chain-of-Sight grounding.
            'bounding_boxes': self.bounding_boxes,
            'region_ids': self.region_ids,
            'metadata': self.metadata,
            'has_visual_features': self.visual_features is not None,
            'has_attention_map': self.attention_map is not None,
        }

    def format_with_tokens(self) -> str:
        """Format step with special tokens for training (with Chain-of-Sight region support)."""
        parts = ["<|step_start|>"]

        # Step type tag, e.g. "<|step_type|>perception".
        parts.append(f"<|step_type|>{self.step_type.value}")

        # Dependencies as a comma-separated ID list, only when present.
        if self.dependencies:
            deps_str = ",".join(map(str, self.dependencies))
            parts.append(f"<|depends_on|>{deps_str}")

        # Region references rendered as numbered <regionN> tokens.
        if self.region_ids:
            region_refs = " ".join([f"<region{rid}>" for rid in self.region_ids])
            parts.append(f"<|region|>{region_refs}")

        # Bounding boxes rendered with 3-decimal precision (normalized coords
        # presumed — TODO confirm against the annotation pipeline).
        if self.bounding_boxes:
            bbox_strs = [f"[{b[0]:.3f},{b[1]:.3f},{b[2]:.3f},{b[3]:.3f}]" for b in self.bounding_boxes]
            parts.append(f"<|bbox|>{','.join(bbox_strs)}")

        # Natural-language description.
        parts.append(f"<|description_start|>{self.description}<|description_end|>")

        # Confidence with fixed 3-decimal formatting.
        parts.append(f"<|confidence_start|>{self.confidence:.3f}<|confidence_end|>")

        # PRM step separator (Math-Shepherd convention).
        parts.append("ки")

        parts.append("<|step_end|>")

        return " ".join(parts)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ReasoningStep':
        """Create from dictionary (inverse of `to_dict` for serializable fields)."""
        return cls(
            step_id=data['step_id'],
            step_type=StepType(data['step_type']),
            description=data['description'],
            confidence=data.get('confidence', 0.0),
            reward=data.get('reward', 0.0),
            dependencies=data.get('dependencies', []),
            # Fix: restore grounding fields; .get keeps old dicts loadable.
            bounding_boxes=data.get('bounding_boxes', []),
            region_ids=data.get('region_ids', []),
            metadata=data.get('metadata', {}),
        )
| |
|
| |
|
@dataclass
class ReasoningChain:
    """
    Complete chain of reasoning steps for a vision-language task.

    Attributes:
        chain_id: Unique identifier for this reasoning chain
        image_path: Path to the input image (None for streaming datasets)
        image: PIL Image object (for streaming datasets); never serialized
        prompt: Input text prompt/question
        steps: Ordered list of reasoning steps
        final_answer: Final output/answer
        total_reward: Cumulative reward from PRM
        is_correct: Whether final answer is correct
        metadata: Additional chain-level information
    """
    chain_id: str
    # Fix: was annotated `str` with a None default; Optional is the true type.
    image_path: Optional[str] = None
    image: Any = None
    prompt: str = ""
    steps: List[ReasoningStep] = field(default_factory=list)
    final_answer: str = ""
    total_reward: float = 0.0
    is_correct: bool = False
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __len__(self) -> int:
        """Number of reasoning steps."""
        return len(self.steps)

    def get_step_rewards(self) -> List[float]:
        """Extract reward for each step."""
        return [step.reward for step in self.steps]

    def get_cumulative_rewards(self) -> List[float]:
        """Calculate cumulative rewards up to each step (running prefix sum)."""
        cumulative = []
        total = 0.0
        for step in self.steps:
            total += step.reward
            cumulative.append(total)
        return cumulative

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization.

        The in-memory `image` object is deliberately omitted (not
        JSON-serializable); only `image_path` is persisted.
        """
        return {
            'chain_id': self.chain_id,
            'image_path': self.image_path,
            'prompt': self.prompt,
            'steps': [step.to_dict() for step in self.steps],
            'final_answer': self.final_answer,
            'total_reward': self.total_reward,
            'is_correct': self.is_correct,
            'metadata': self.metadata
        }

    def format_with_tokens(self) -> str:
        """Format entire reasoning chain with special tokens for training."""
        parts = ["<|reasoning_start|>"]

        # Each step renders itself with its own step-level tokens.
        for step in self.steps:
            parts.append(step.format_with_tokens())

        parts.append("<|reasoning_end|>")

        # Final answer comes after the reasoning span.
        parts.append(f"<|answer_start|>{self.final_answer}<|answer_end|>")

        return " ".join(parts)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ReasoningChain':
        """Create from dictionary.

        Only `chain_id` is required; every other field falls back to its
        dataclass default (fix: optional keys previously raised KeyError).
        """
        return cls(
            chain_id=data['chain_id'],
            image_path=data.get('image_path'),
            prompt=data.get('prompt', ''),
            steps=[ReasoningStep.from_dict(s) for s in data.get('steps', [])],
            final_answer=data.get('final_answer', ''),
            total_reward=data.get('total_reward', 0.0),
            is_correct=data.get('is_correct', False),
            metadata=data.get('metadata', {}),
        )

    def save_json(self, path: str) -> None:
        """Save reasoning chain to JSON file (UTF-8, human-readable)."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

    @classmethod
    def load_json(cls, path: str) -> 'ReasoningChain':
        """Load reasoning chain from JSON file."""
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return cls.from_dict(data)
| |
|
| |
|
class StepDataset(Dataset):
    """
    Dataset for training with step-level reasoning chains.

    Supports both supervised fine-tuning and RL training with process rewards.
    Each item is one full reasoning chain, tokenized with the special-token
    format plus per-step tensors (rewards, optional visual features).
    """

    def __init__(
        self,
        reasoning_chains: List[ReasoningChain],
        tokenizer: Any,
        max_steps: int = 10,
        include_visual_features: bool = True,
        filter_incorrect: bool = False,
    ):
        """
        Initialize step dataset.

        Args:
            reasoning_chains: List of annotated reasoning chains
            tokenizer: Tokenizer for text encoding (presumably a HuggingFace-style
                callable returning 'input_ids'/'attention_mask' — TODO confirm)
            max_steps: Maximum number of steps per chain
            include_visual_features: Whether to include visual features
            filter_incorrect: If True, only include correct chains
        """
        self.tokenizer = tokenizer
        self.max_steps = max_steps
        self.include_visual_features = include_visual_features

        # Optionally keep only chains whose final answer is marked correct.
        if filter_incorrect:
            self.chains = [c for c in reasoning_chains if c.is_correct]
        else:
            self.chains = reasoning_chains

        # Pre-computed corpus statistics; len(chain) is the step count.
        self.avg_steps = sum(len(c) for c in self.chains) / len(self.chains) if self.chains else 0
        self.total_steps = sum(len(c) for c in self.chains)

    def __len__(self) -> int:
        return len(self.chains)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Get a single reasoning chain with all steps.

        Returns:
            Dictionary containing:
                - prompt_ids: Tokenized prompt
                - formatted_input_ids: Tokenized chain with special tokens
                - formatted_text: Raw text with special tokens
                - step_descriptions: List of step descriptions
                - step_types: List of step types
                - step_rewards: List of process rewards
                - visual_features: List of visual features (if enabled)
                - attention_maps: List of attention maps (if available)
                - final_answer: Target answer
                - is_correct: Ground truth correctness
                - chain_metadata: Additional information

        NOTE(review): the per-step lists below are truncated to `max_steps`,
        while `formatted_text` and `num_steps` cover ALL steps of the chain —
        confirm this asymmetry is intended by downstream collation code.
        """
        chain = self.chains[idx]

        # Full chain rendered with special tokens (<|step_start|>, ки, ...).
        formatted_text = chain.format_with_tokens()

        # Tokenize the formatted chain; hard cap of 2048 tokens.
        formatted_encoding = self.tokenizer(
            formatted_text,
            truncation=True,
            max_length=2048,
            return_tensors='pt',
        )

        # Tokenize the raw question/prompt separately (cap 512).
        prompt_encoding = self.tokenizer(
            chain.prompt,
            truncation=True,
            max_length=512,
            return_tensors='pt',
        )

        # Per-step payloads, truncated to the first max_steps steps.
        step_descriptions = []
        step_types = []
        step_rewards = []
        visual_features = []
        attention_maps = []

        for step in chain.steps[:self.max_steps]:
            step_descriptions.append(step.description)
            step_types.append(step.step_type.value)
            step_rewards.append(step.reward)

            if self.include_visual_features:
                if step.visual_features is not None:
                    visual_features.append(step.visual_features)
                else:
                    # Zero-vector placeholder keeps the stack rectangular.
                    # Assumes 768-dim visual features — TODO confirm against
                    # the vision encoder's output size.
                    visual_features.append(torch.zeros(768))

            # Attention maps are only collected when present, so this list
            # may be shorter than the step lists (no placeholder is added).
            if step.attention_map is not None:
                attention_maps.append(step.attention_map)

        # Tokenize the target answer (cap 256).
        answer_encoding = self.tokenizer(
            chain.final_answer,
            truncation=True,
            max_length=256,
            return_tensors='pt',
        )

        result = {
            # squeeze(0) drops the batch dim added by return_tensors='pt'.
            'prompt_ids': prompt_encoding['input_ids'].squeeze(0),
            'prompt_mask': prompt_encoding['attention_mask'].squeeze(0),
            'formatted_input_ids': formatted_encoding['input_ids'].squeeze(0),
            'formatted_attention_mask': formatted_encoding['attention_mask'].squeeze(0),
            'formatted_text': formatted_text,
            'step_descriptions': step_descriptions,
            'step_types': step_types,
            'step_rewards': torch.tensor(step_rewards, dtype=torch.float32),
            'answer_ids': answer_encoding['input_ids'].squeeze(0),
            'answer_mask': answer_encoding['attention_mask'].squeeze(0),
            'is_correct': torch.tensor(chain.is_correct, dtype=torch.bool),
            'total_reward': torch.tensor(chain.total_reward, dtype=torch.float32),
            'num_steps': len(chain.steps),
            'chain_id': chain.chain_id,
            'image_path': chain.image_path,
        }

        # Stack into (num_steps, 768); omitted entirely for empty chains.
        if self.include_visual_features and visual_features:
            result['visual_features'] = torch.stack(visual_features)

        # Left as a plain list: maps may differ in spatial size across steps.
        if attention_maps:
            result['attention_maps'] = attention_maps

        return result

    @classmethod
    def from_json_dir(
        cls,
        json_dir: str,
        tokenizer: Any,
        **kwargs
    ) -> 'StepDataset':
        """
        Load dataset from directory of JSON files.

        Args:
            json_dir: Directory containing reasoning chain JSON files
            tokenizer: Tokenizer for text encoding
            **kwargs: Additional arguments for StepDataset
        """
        from pathlib import Path

        json_path = Path(json_dir)
        chains = []

        # Best-effort load: one malformed file skips that chain only.
        # NOTE(review): broad except + print swallows all errors; consider
        # logging or narrowing to (OSError, json.JSONDecodeError, KeyError).
        for json_file in json_path.glob('*.json'):
            try:
                chain = ReasoningChain.load_json(str(json_file))
                chains.append(chain)
            except Exception as e:
                print(f"Error loading {json_file}: {e}")

        return cls(chains, tokenizer, **kwargs)

    def get_statistics(self) -> Dict[str, Any]:
        """Calculate dataset statistics.

        Returns chain count, step counts, accuracy (fraction of chains with
        is_correct=True), per-type step distribution, and reward summary.
        Counts ALL steps of every chain, ignoring `max_steps` truncation.
        """
        step_types = {}
        rewards = []
        correct_count = 0

        for chain in self.chains:
            if chain.is_correct:
                correct_count += 1

            for step in chain.steps:
                step_type = step.step_type.value
                step_types[step_type] = step_types.get(step_type, 0) + 1
                rewards.append(step.reward)

        return {
            'num_chains': len(self.chains),
            'total_steps': self.total_steps,
            'avg_steps_per_chain': self.avg_steps,
            'accuracy': correct_count / len(self.chains) if self.chains else 0,
            'step_type_distribution': step_types,
            'avg_step_reward': sum(rewards) / len(rewards) if rewards else 0,
            'min_step_reward': min(rewards) if rewards else 0,
            'max_step_reward': max(rewards) if rewards else 0,
        }
| |
|