File size: 16,357 Bytes
2b8876a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3024a91
2b8876a
 
 
 
 
 
 
 
3024a91
 
 
 
 
2b8876a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
"""
Data structures for step-level reasoning chains with vision-language annotations.
"""

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any
from enum import Enum
import torch
from torch.utils.data import Dataset
import json

# Special tokens for chain-of-thought reasoning.
#
# Tokens and their human-readable descriptions are kept in a single table of
# (token, description) pairs so the two exported parallel lists
# (SPECIAL_TOKENS / SPECIAL_TOKEN_DESCRIPTIONS) can never drift out of sync.
_TOKEN_SPECS = [
    # Reasoning structure
    ("<|reasoning_start|>", "start of reasoning"),
    ("<|reasoning_end|>", "end of reasoning"),

    # Step-level markers
    ("<|step_start|>", "start of step"),
    ("<|step_end|>", "end of step"),
    ("<|step_type|>", "step type indicator"),

    # Step components
    ("<|description_start|>", "start of description"),
    ("<|description_end|>", "end of description"),
    ("<|confidence_start|>", "start of confidence"),
    ("<|confidence_end|>", "end of confidence"),

    # Dependencies
    ("<|depends_on|>", "depends on previous step"),

    # PRM evaluation tokens (from PRM800K)
    ("ки", "step separator for PRM"),   # step separator for PRM scoring
    ("к+и", "positive step"),           # positive step label
    ("к-и", "negative step"),           # negative step label

    # Final answer
    ("<|answer_start|>", "start of answer"),
    ("<|answer_end|>", "end of answer"),

    # Region tokens (inspired by Chain-of-Sight for visual grounding)
    ("<|region|>", "region reference"),
    ("<|bbox|>", "bounding box coordinates"),
]

SPECIAL_TOKENS = [token for token, _ in _TOKEN_SPECS]
SPECIAL_TOKEN_DESCRIPTIONS = [description for _, description in _TOKEN_SPECS]

# Generate additional region tokens like <region1>, <region2>, ... <region10>
# for referring to specific visual regions (Chain-of-Sight style)
REGION_TOKENS = [f"<region{i}>" for i in range(1, 11)]
SPECIAL_TOKENS.extend(REGION_TOKENS)
SPECIAL_TOKEN_DESCRIPTIONS.extend([f"visual region {i}" for i in range(1, 11)])


class StepType(Enum):
    """Category of reasoning performed at one chain-of-thought step.

    The string values are what appear in serialized chains and in the
    ``<|step_type|>`` token stream.
    """

    PERCEPTION = "perception"        # visual understanding of image content
    LOCALIZATION = "localization"    # spatial reasoning / locating regions
    COMPARISON = "comparison"        # comparing visual elements
    COUNTING = "counting"            # numerical reasoning over instances
    INFERENCE = "inference"          # logical deduction from earlier steps
    COMPOSITION = "composition"      # combining information across steps
    VERIFICATION = "verification"    # checking internal consistency


@dataclass
class ReasoningStep:
    """
    A single step in the chain of thought reasoning process.
    
    Attributes:
        step_id: Unique identifier for this step
        step_type: Type of reasoning performed
        description: Natural language description of the reasoning
        visual_features: Extracted visual features relevant to this step
        attention_map: Spatial attention weights (H, W)
        confidence: Model's confidence in this step [0, 1]
        reward: Process reward score assigned by PRM [-1, 1]
        dependencies: IDs of previous steps this depends on
        bounding_boxes: List of [x1, y1, x2, y2] bounding boxes (Chain-of-Sight style)
        region_ids: List of region IDs referenced in this step (e.g., [1, 2])
        metadata: Additional step-specific information
    """
    step_id: int
    step_type: StepType
    description: str
    visual_features: Optional[torch.Tensor] = None
    attention_map: Optional[torch.Tensor] = None
    confidence: float = 0.0
    reward: float = 0.0
    dependencies: List[int] = field(default_factory=list)
    bounding_boxes: List[List[float]] = field(default_factory=list)  # [[x1,y1,x2,y2], ...]
    region_ids: List[int] = field(default_factory=list)  # [1, 2, 3]
    metadata: Dict[str, Any] = field(default_factory=dict)
    
    def to_dict(self) -> Dict[str, Any]:
        """Convert to a JSON-serializable dictionary.
        
        Tensor fields are not serialized; only boolean flags recording their
        presence are stored. Grounding data (bounding_boxes, region_ids) is
        included so a to_dict/from_dict round trip is lossless for the
        non-tensor fields.
        """
        return {
            'step_id': self.step_id,
            'step_type': self.step_type.value,
            'description': self.description,
            'confidence': self.confidence,
            'reward': self.reward,
            'dependencies': self.dependencies,
            'bounding_boxes': self.bounding_boxes,
            'region_ids': self.region_ids,
            'metadata': self.metadata,
            'has_visual_features': self.visual_features is not None,
            'has_attention_map': self.attention_map is not None,
        }
    
    def format_with_tokens(self) -> str:
        """Format step with special tokens for training (with Chain-of-Sight region support)."""
        parts = ["<|step_start|>"]
        
        # Add step type
        parts.append(f"<|step_type|>{self.step_type.value}")
        
        # Add dependencies if any
        if self.dependencies:
            deps_str = ",".join(map(str, self.dependencies))
            parts.append(f"<|depends_on|>{deps_str}")
        
        # Add region references (Chain-of-Sight style)
        if self.region_ids:
            region_refs = " ".join([f"<region{rid}>" for rid in self.region_ids])
            parts.append(f"<|region|>{region_refs}")
        
        # Add bounding boxes if present (coordinates rendered to 3 decimal places)
        if self.bounding_boxes:
            bbox_strs = [f"[{b[0]:.3f},{b[1]:.3f},{b[2]:.3f},{b[3]:.3f}]" for b in self.bounding_boxes]
            parts.append(f"<|bbox|>{','.join(bbox_strs)}")
        
        # Add description
        parts.append(f"<|description_start|>{self.description}<|description_end|>")
        
        # Add confidence
        parts.append(f"<|confidence_start|>{self.confidence:.3f}<|confidence_end|>")
        
        # Add PRM separator
        parts.append("ки")
        
        parts.append("<|step_end|>")
        
        return " ".join(parts)
    
    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ReasoningStep':
        """Create from dictionary (inverse of to_dict; tensors are not restored).
        
        Optional keys fall back to the dataclass defaults, so dictionaries
        produced by older versions (without bounding_boxes/region_ids) still load.
        """
        return cls(
            step_id=data['step_id'],
            step_type=StepType(data['step_type']),
            description=data['description'],
            confidence=data.get('confidence', 0.0),
            reward=data.get('reward', 0.0),
            dependencies=data.get('dependencies', []),
            bounding_boxes=data.get('bounding_boxes', []),
            region_ids=data.get('region_ids', []),
            metadata=data.get('metadata', {}),
        )


@dataclass
class ReasoningChain:
    """
    Complete chain of reasoning steps for a vision-language task.
    
    Attributes:
        chain_id: Unique identifier for this reasoning chain
        image_path: Path to the input image (may be None for streaming datasets)
        image: PIL Image object (for streaming datasets; never serialized)
        prompt: Input text prompt/question
        steps: Ordered list of reasoning steps
        final_answer: Final output/answer
        total_reward: Cumulative reward from PRM
        is_correct: Whether final answer is correct
        metadata: Additional chain-level information
    """
    chain_id: str
    image_path: Optional[str] = None
    image: Any = None
    prompt: str = ""
    steps: List[ReasoningStep] = field(default_factory=list)
    final_answer: str = ""
    total_reward: float = 0.0
    is_correct: bool = False
    metadata: Dict[str, Any] = field(default_factory=dict)
    
    def __len__(self) -> int:
        """Number of reasoning steps."""
        return len(self.steps)
    
    def get_step_rewards(self) -> List[float]:
        """Extract reward for each step."""
        return [step.reward for step in self.steps]
    
    def get_cumulative_rewards(self) -> List[float]:
        """Calculate cumulative rewards up to each step (running prefix sums)."""
        cumulative = []
        total = 0.0
        for step in self.steps:
            total += step.reward
            cumulative.append(total)
        return cumulative
    
    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization (the in-memory image is omitted)."""
        return {
            'chain_id': self.chain_id,
            'image_path': self.image_path,
            'prompt': self.prompt,
            'steps': [step.to_dict() for step in self.steps],
            'final_answer': self.final_answer,
            'total_reward': self.total_reward,
            'is_correct': self.is_correct,
            'metadata': self.metadata
        }
    
    def format_with_tokens(self) -> str:
        """Format entire reasoning chain with special tokens for training."""
        parts = ["<|reasoning_start|>"]
        
        # Add all steps
        for step in self.steps:
            parts.append(step.format_with_tokens())
        
        parts.append("<|reasoning_end|>")
        
        # Add final answer
        parts.append(f"<|answer_start|>{self.final_answer}<|answer_end|>")
        
        return " ".join(parts)
    
    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ReasoningChain':
        """Create from dictionary (inverse of to_dict).
        
        Only chain_id is required; every other key falls back to the dataclass
        default, so partially-populated dictionaries load without KeyError.
        """
        return cls(
            chain_id=data['chain_id'],
            image_path=data.get('image_path'),
            prompt=data.get('prompt', ""),
            steps=[ReasoningStep.from_dict(s) for s in data.get('steps', [])],
            final_answer=data.get('final_answer', ""),
            total_reward=data.get('total_reward', 0.0),
            is_correct=data.get('is_correct', False),
            metadata=data.get('metadata', {}),
        )
    
    def save_json(self, path: str) -> None:
        """Save reasoning chain to JSON file (UTF-8, human-readable indentation)."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
    
    @classmethod
    def load_json(cls, path: str) -> 'ReasoningChain':
        """Load reasoning chain from JSON file."""
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return cls.from_dict(data)


class StepDataset(Dataset):
    """
    Dataset for training with step-level reasoning chains.
    
    Supports both supervised fine-tuning and RL training with process rewards.
    """
    
    def __init__(
        self,
        reasoning_chains: List[ReasoningChain],
        tokenizer: Any,
        max_steps: int = 10,
        include_visual_features: bool = True,
        filter_incorrect: bool = False,
        visual_feature_dim: int = 768,
        max_length: int = 2048,
    ):
        """
        Initialize step dataset.
        
        Args:
            reasoning_chains: List of annotated reasoning chains
            tokenizer: Tokenizer for text encoding
            max_steps: Maximum number of steps per chain
            include_visual_features: Whether to include visual features
            filter_incorrect: If True, only include correct chains
            visual_feature_dim: Dimension of the zero placeholder used when a
                step has no visual features (default 768)
            max_length: Max token length for the formatted reasoning chain
                (default 2048; longer than prompts to fit full chains)
        """
        self.tokenizer = tokenizer
        self.max_steps = max_steps
        self.include_visual_features = include_visual_features
        self.visual_feature_dim = visual_feature_dim
        self.max_length = max_length
        
        # Filter chains if needed
        if filter_incorrect:
            self.chains = [c for c in reasoning_chains if c.is_correct]
        else:
            self.chains = reasoning_chains
        
        # Pre-compute statistics (guard against an empty dataset)
        self.avg_steps = sum(len(c) for c in self.chains) / len(self.chains) if self.chains else 0
        self.total_steps = sum(len(c) for c in self.chains)
    
    def __len__(self) -> int:
        return len(self.chains)
    
    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Get a single reasoning chain with all steps.
        
        Returns:
            Dictionary containing:
                - prompt_ids: Tokenized prompt
                - formatted_input_ids: Tokenized chain with special tokens
                - formatted_text: Raw text with special tokens
                - step_descriptions: List of step descriptions
                - step_types: List of step types
                - step_rewards: List of process rewards
                - visual_features: List of visual features (if enabled)
                - attention_maps: List of attention maps (if available)
                - final_answer: Target answer
                - is_correct: Ground truth correctness
                - chain_metadata: Additional information
        """
        chain = self.chains[idx]
        
        # Format chain with special tokens
        formatted_text = chain.format_with_tokens()
        
        # Tokenize formatted text (for training)
        formatted_encoding = self.tokenizer(
            formatted_text,
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        
        # Tokenize prompt (for compatibility)
        prompt_encoding = self.tokenizer(
            chain.prompt,
            truncation=True,
            max_length=512,
            return_tensors='pt',
        )
        
        # Prepare step data
        step_descriptions = []
        step_types = []
        step_rewards = []
        visual_features = []
        attention_maps = []
        
        for step in chain.steps[:self.max_steps]:
            step_descriptions.append(step.description)
            step_types.append(step.step_type.value)
            step_rewards.append(step.reward)
            
            if self.include_visual_features:
                if step.visual_features is not None:
                    visual_features.append(step.visual_features)
                else:
                    # Zero placeholder keeps the feature list aligned with steps
                    visual_features.append(torch.zeros(self.visual_feature_dim))
                
                if step.attention_map is not None:
                    attention_maps.append(step.attention_map)
        
        # Tokenize final answer
        answer_encoding = self.tokenizer(
            chain.final_answer,
            truncation=True,
            max_length=256,
            return_tensors='pt',
        )
        
        result = {
            'prompt_ids': prompt_encoding['input_ids'].squeeze(0),
            'prompt_mask': prompt_encoding['attention_mask'].squeeze(0),
            'formatted_input_ids': formatted_encoding['input_ids'].squeeze(0),
            'formatted_attention_mask': formatted_encoding['attention_mask'].squeeze(0),
            'formatted_text': formatted_text,
            'step_descriptions': step_descriptions,
            'step_types': step_types,
            'step_rewards': torch.tensor(step_rewards, dtype=torch.float32),
            'answer_ids': answer_encoding['input_ids'].squeeze(0),
            'answer_mask': answer_encoding['attention_mask'].squeeze(0),
            'is_correct': torch.tensor(chain.is_correct, dtype=torch.bool),
            'total_reward': torch.tensor(chain.total_reward, dtype=torch.float32),
            'num_steps': len(chain.steps),
            'chain_id': chain.chain_id,
            'image_path': chain.image_path,
        }
        
        # NOTE: stacking assumes all per-step features share one shape —
        # real features should match visual_feature_dim; verify upstream
        if self.include_visual_features and visual_features:
            result['visual_features'] = torch.stack(visual_features)
        
        if attention_maps:
            result['attention_maps'] = attention_maps
        
        return result
    
    @classmethod
    def from_json_dir(
        cls,
        json_dir: str,
        tokenizer: Any,
        **kwargs
    ) -> 'StepDataset':
        """
        Load dataset from directory of JSON files.
        
        Files that fail to parse are skipped with a printed warning rather
        than aborting the whole load.
        
        Args:
            json_dir: Directory containing reasoning chain JSON files
            tokenizer: Tokenizer for text encoding
            **kwargs: Additional arguments for StepDataset
        """
        from pathlib import Path
        
        json_path = Path(json_dir)
        chains = []
        
        for json_file in json_path.glob('*.json'):
            try:
                chain = ReasoningChain.load_json(str(json_file))
                chains.append(chain)
            except Exception as e:
                print(f"Error loading {json_file}: {e}")
        
        return cls(chains, tokenizer, **kwargs)
    
    def get_statistics(self) -> Dict[str, Any]:
        """Calculate dataset statistics (counts, accuracy, reward summary)."""
        step_types = {}
        rewards = []
        correct_count = 0
        
        for chain in self.chains:
            if chain.is_correct:
                correct_count += 1
            
            for step in chain.steps:
                step_type = step.step_type.value
                step_types[step_type] = step_types.get(step_type, 0) + 1
                rewards.append(step.reward)
        
        return {
            'num_chains': len(self.chains),
            'total_steps': self.total_steps,
            'avg_steps_per_chain': self.avg_steps,
            'accuracy': correct_count / len(self.chains) if self.chains else 0,
            'step_type_distribution': step_types,
            'avg_step_reward': sum(rewards) / len(rewards) if rewards else 0,
            'min_step_reward': min(rewards) if rewards else 0,
            'max_step_reward': max(rewards) if rewards else 0,
        }