"""
Data preparation utilities for step-level reasoning chains.

Provides tools to:
1. Create reasoning chain annotations from existing VQA/reasoning datasets
2. Generate synthetic step-level annotations
3. Validate and convert data formats
"""

import json
import argparse
from pathlib import Path
from typing import List, Dict, Any, Optional
import logging

from PIL import Image
from tqdm import tqdm

from src.reasoning.step_data import ReasoningChain, ReasoningStep, StepType

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ReasoningDataPreparator:
    """Prepares step-level reasoning data from various sources."""

    def __init__(self, output_dir: str):
        """
        Initialize data preparator.

        Args:
            output_dir: Directory to save prepared data. Created (with
                parents) if missing, along with train/val/test subdirs.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # One subdirectory per dataset split; chains are saved as
        # individual JSON files inside these.
        for split in ['train', 'val', 'test']:
            (self.output_dir / split).mkdir(exist_ok=True)

    def from_vqa_dataset(
        self,
        vqa_data: List[Dict[str, Any]],
        split: str = 'train',
        generate_steps: bool = True,
    ) -> List[ReasoningChain]:
        """
        Convert a VQA dataset to reasoning chains.

        Each sample becomes one chain, saved to ``<output_dir>/<split>/``
        as ``<chain_id>.json``.

        Args:
            vqa_data: List of VQA samples with keys: image_path, question, answer
            split: Dataset split
            generate_steps: If True, generate intermediate reasoning steps
                heuristically; otherwise emit a single direct-answer step.

        Returns:
            List of reasoning chains
        """
        chains = []
        logger.info(f"Converting {len(vqa_data)} VQA samples to reasoning chains")

        for idx, sample in enumerate(tqdm(vqa_data)):
            chain_id = f"{split}_vqa_{idx:06d}"

            if generate_steps:
                steps = self._generate_steps_for_vqa(
                    question=sample['question'],
                    answer=sample['answer'],
                    image_path=sample.get('image_path', ''),
                )
            else:
                # Single step: the answer is stated directly with no
                # intermediate reasoning.
                steps = [
                    ReasoningStep(
                        step_id=0,
                        step_type=StepType.INFERENCE,
                        description=f"Answer: {sample['answer']}",
                        confidence=1.0,
                        reward=1.0,
                    )
                ]

            total_reward = sum(step.reward for step in steps)

            chain = ReasoningChain(
                chain_id=chain_id,
                image_path=sample.get('image_path', ''),
                prompt=sample['question'],
                steps=steps,
                final_answer=sample['answer'],
                total_reward=total_reward,
                is_correct=True,  # Assume correct if from ground truth
                metadata={
                    'source': 'vqa',
                    'original_sample_id': sample.get('id', idx),
                }
            )
            chains.append(chain)

            # Persist each chain immediately so a partial run still
            # leaves usable files on disk.
            save_path = self.output_dir / split / f"{chain_id}.json"
            chain.save_json(str(save_path))

        logger.info(f"Saved {len(chains)} reasoning chains to {self.output_dir / split}")
        return chains

    def _generate_steps_for_vqa(
        self,
        question: str,
        answer: str,
        image_path: str = '',
    ) -> List[ReasoningStep]:
        """
        Generate plausible reasoning steps for a VQA question.

        This is a heuristic approach keyed on question phrasing. For
        production, use a language model or manual annotation.

        Args:
            question: VQA question
            answer: Ground truth answer
            image_path: Path to image (unused by the heuristic; kept for
                interface compatibility)

        Returns:
            List of reasoning steps: a perception step, one or two
            type-specific steps, and a final verification step.
        """
        steps = []

        # Step 0: Perception (restate/understand the question).
        steps.append(ReasoningStep(
            step_id=0,
            step_type=StepType.PERCEPTION,
            description=f"I need to answer: {question}",
            confidence=0.9,
            reward=0.7,
            dependencies=[],
        ))

        # Type-specific middle steps, chosen by keyword heuristics.
        question_lower = question.lower()

        if any(word in question_lower for word in ['how many', 'count']):
            # Counting question: localize then count.
            steps.append(ReasoningStep(
                step_id=1,
                step_type=StepType.LOCALIZATION,
                description="I identify and locate the relevant objects",
                confidence=0.85,
                reward=0.75,
                dependencies=[0],
            ))
            steps.append(ReasoningStep(
                step_id=2,
                step_type=StepType.COUNTING,
                description=f"I count the objects and determine the answer is {answer}",
                confidence=0.9,
                reward=0.85,
                dependencies=[1],
            ))
        elif any(word in question_lower for word in ['where', 'location']):
            # Localization question.
            steps.append(ReasoningStep(
                step_id=1,
                step_type=StepType.LOCALIZATION,
                description=f"I determine the location: {answer}",
                confidence=0.88,
                reward=0.8,
                dependencies=[0],
            ))
        elif any(word in question_lower for word in ['what color', 'which color']):
            # Color perception.
            steps.append(ReasoningStep(
                step_id=1,
                step_type=StepType.PERCEPTION,
                description=f"I identify the color as {answer}",
                confidence=0.92,
                reward=0.85,
                dependencies=[0],
            ))
        elif any(word in question_lower for word in ['compare', 'difference', 'similar']):
            # Comparison: compare then infer.
            steps.append(ReasoningStep(
                step_id=1,
                step_type=StepType.COMPARISON,
                description="I compare the relevant elements",
                confidence=0.8,
                reward=0.75,
                dependencies=[0],
            ))
            steps.append(ReasoningStep(
                step_id=2,
                step_type=StepType.INFERENCE,
                description=f"Based on comparison, the answer is {answer}",
                confidence=0.85,
                reward=0.8,
                dependencies=[1],
            ))
        else:
            # General inference fallback.
            steps.append(ReasoningStep(
                step_id=1,
                step_type=StepType.INFERENCE,
                description=f"Based on the image, I conclude: {answer}",
                confidence=0.85,
                reward=0.8,
                dependencies=[0],
            ))

        # Final verification step, depending on the last step so far.
        steps.append(ReasoningStep(
            step_id=len(steps),
            step_type=StepType.VERIFICATION,
            description=f"I verify my answer: {answer}",
            confidence=0.9,
            reward=0.85,
            dependencies=[len(steps) - 1],
        ))

        return steps

    def from_gqa_dataset(
        self,
        gqa_data: List[Dict[str, Any]],
        split: str = 'train',
    ) -> List[ReasoningChain]:
        """
        Convert GQA dataset (which has a semantic parse) to reasoning chains.

        GQA provides structured programs that can be converted to steps.
        Samples without a 'semantic' field fall back to the VQA heuristic.

        Args:
            gqa_data: List of GQA samples
            split: Dataset split

        Returns:
            List of reasoning chains
        """
        chains = []
        logger.info(f"Converting {len(gqa_data)} GQA samples to reasoning chains")

        for idx, sample in enumerate(tqdm(gqa_data)):
            steps = []

            if 'semantic' in sample:
                # Map each semantic program operation to one step;
                # chain each step to its predecessor.
                for step_idx, operation in enumerate(sample['semantic']):
                    step_type = self._map_gqa_operation_to_step_type(operation)
                    steps.append(ReasoningStep(
                        step_id=step_idx,
                        step_type=step_type,
                        description=operation.get('argument', ''),
                        confidence=0.85,
                        reward=0.8,
                        dependencies=[step_idx - 1] if step_idx > 0 else [],
                    ))
            else:
                # Fallback to heuristically generated steps.
                steps = self._generate_steps_for_vqa(
                    sample['question'],
                    sample['answer'],
                    sample.get('imageId', ''),
                )

            chain = ReasoningChain(
                chain_id=f"{split}_gqa_{idx:06d}",
                image_path=sample.get('imageId', ''),
                prompt=sample['question'],
                steps=steps,
                final_answer=sample['answer'],
                total_reward=sum(s.reward for s in steps),
                is_correct=True,
                metadata={'source': 'gqa', 'fullAnswer': sample.get('fullAnswer', '')},
            )
            chains.append(chain)

            save_path = self.output_dir / split / f"{chain.chain_id}.json"
            chain.save_json(str(save_path))

        # Log saved-count for consistency with from_vqa_dataset.
        logger.info(f"Saved {len(chains)} reasoning chains to {self.output_dir / split}")
        return chains

    def _map_gqa_operation_to_step_type(self, operation: Dict[str, Any]) -> StepType:
        """
        Map a GQA semantic operation to a step type.

        Matches on whole words rather than raw substrings: a substring
        test like ``'or' in op`` misclassifies operations such as
        "filter color" (whose "color" contains "or") as COMPOSITION.
        Branch priority is preserved from the original heuristic.
        """
        tokens = operation.get('operation', '').lower().split()

        if any(t in ('select', 'relate') for t in tokens):
            return StepType.LOCALIZATION
        elif 'query' in tokens:
            return StepType.PERCEPTION
        elif any(t in ('verify', 'choose') for t in tokens):
            return StepType.VERIFICATION
        elif any(t in ('and', 'or') for t in tokens):
            return StepType.COMPOSITION
        else:
            return StepType.INFERENCE

    def validate_dataset(self, split: str = 'train') -> Dict[str, Any]:
        """
        Validate a reasoning chain dataset on disk.

        Loads every ``*.json`` chain in the split directory, checks basic
        structure (non-empty steps, non-empty final answer), and
        aggregates step-type counts. Load failures are recorded in
        ``stats['errors']`` rather than raised.

        Args:
            split: Dataset split to validate

        Returns:
            Validation statistics dict with keys: num_chains,
            num_steps_total, avg_steps_per_chain, step_types, errors.
        """
        split_dir = self.output_dir / split
        chain_files = list(split_dir.glob('*.json'))

        logger.info(f"Validating {len(chain_files)} chains in {split} split")

        stats = {
            'num_chains': len(chain_files),
            'num_steps_total': 0,
            'avg_steps_per_chain': 0,
            'step_types': {},
            'errors': [],
        }

        for chain_file in tqdm(chain_files):
            try:
                chain = ReasoningChain.load_json(str(chain_file))

                # Structural checks.
                if not chain.steps:
                    stats['errors'].append(f"{chain_file.name}: No steps")
                if not chain.final_answer:
                    stats['errors'].append(f"{chain_file.name}: No final answer")

                stats['num_steps_total'] += len(chain.steps)

                # Per-type step histogram.
                for step in chain.steps:
                    step_type = step.step_type.value
                    stats['step_types'][step_type] = stats['step_types'].get(step_type, 0) + 1

            except Exception as e:
                # Best-effort validation: record and continue.
                stats['errors'].append(f"{chain_file.name}: {str(e)}")

        # Guard against division by zero on an empty split.
        stats['avg_steps_per_chain'] = stats['num_steps_total'] / stats['num_chains'] if stats['num_chains'] > 0 else 0

        logger.info(f"Validation complete: {stats}")
        return stats


def main():
    """CLI entry point: convert a source dataset to reasoning chains."""
    parser = argparse.ArgumentParser(description="Prepare step-level reasoning data")
    parser.add_argument("--source", type=str, required=True, choices=['vqa', 'gqa', 'custom'])
    parser.add_argument("--input", type=str, required=True, help="Input data file (JSON)")
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
    parser.add_argument("--split", type=str, default='train', choices=['train', 'val', 'test'])
    parser.add_argument("--validate", action='store_true', help="Validate after conversion")

    args = parser.parse_args()

    preparator = ReasoningDataPreparator(args.output_dir)

    # Load input data; JSON is read as UTF-8 explicitly so behavior does
    # not depend on the platform's default encoding.
    logger.info(f"Loading data from {args.input}")
    with open(args.input, 'r', encoding='utf-8') as f:
        input_data = json.load(f)

    # Convert based on source.
    if args.source == 'vqa':
        chains = preparator.from_vqa_dataset(input_data, split=args.split)
    elif args.source == 'gqa':
        chains = preparator.from_gqa_dataset(input_data, split=args.split)
    else:
        # 'custom' is accepted by argparse but not implemented yet.
        logger.error(f"Unsupported source: {args.source}")
        return

    logger.info(f"Converted {len(chains)} samples to reasoning chains")

    if args.validate:
        stats = preparator.validate_dataset(args.split)
        logger.info(f"Validation results: {stats}")


if __name__ == "__main__":
    main()