"""
Dataset loader for HuggingFaceM4/FineVision.

This adapter loads FineVision image-text pairs and optionally generates
synthetic reasoning chains for step-level CoT training.
"""

import io
import json
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import torch
from torch.utils.data import Dataset
from datasets import load_dataset
from PIL import Image

from .step_data import (
    ReasoningChain,
    ReasoningStep,
    StepType,
    SPECIAL_TOKENS,
)


class FineVisionDataset(Dataset):
    """
    Dataset adapter for HuggingFaceM4/FineVision with optional synthetic
    reasoning chain generation.

    FineVision contains image-text pairs for vision-language tasks. This
    adapter can either use the raw data or generate synthetic reasoning
    chains to augment training with explicit step-by-step reasoning.
    """

    def __init__(
        self,
        dataset_name: str = "HuggingFaceM4/FineVision",
        subset: Optional[str] = None,
        split: str = "train",
        streaming: bool = True,
        max_samples: Optional[int] = None,
        image_key: str = "image",
        text_key: str = "text",
        generate_synthetic_reasoning: bool = True,
        num_synthetic_steps: int = 3,
        cache_dir: Optional[str] = None,
    ):
        """
        Initialize FineVision dataset adapter.

        Args:
            dataset_name: HuggingFace dataset identifier
            subset: Dataset subset/config name (if applicable)
            split: Dataset split (train, validation, test)
            streaming: Use streaming mode for large datasets
            max_samples: Maximum number of samples to load (None = all)
            image_key: Column name for images
            text_key: Column name for text/captions
            generate_synthetic_reasoning: If True, create synthetic reasoning steps
            num_synthetic_steps: Number of reasoning steps to generate
            cache_dir: Directory to cache downloaded data

        Raises:
            RuntimeError: If the HuggingFace dataset cannot be loaded.
        """
        self.dataset_name = dataset_name
        self.subset = subset
        self.split = split
        self.streaming = streaming
        self.max_samples = max_samples
        self.image_key = image_key
        self.text_key = text_key
        self.generate_synthetic_reasoning = generate_synthetic_reasoning
        self.num_synthetic_steps = num_synthetic_steps
        self.cache_dir = cache_dir

        print(f"Loading {dataset_name} ({split} split)...")

        # Load dataset from HuggingFace. Any failure (network, auth, bad
        # name/config) is re-raised as RuntimeError with context.
        try:
            if streaming:
                self.hf_dataset = load_dataset(
                    dataset_name,
                    name=subset,
                    split=split,
                    streaming=True,
                    cache_dir=cache_dir,
                )
                # Streaming datasets are iterable-only; keep a live iterator
                # so __getitem__ can pull samples sequentially.
                self.hf_dataset_iter = iter(self.hf_dataset)
            else:
                self.hf_dataset = load_dataset(
                    dataset_name,
                    name=subset,
                    split=split,
                    cache_dir=cache_dir,
                )
                if max_samples:
                    self.hf_dataset = self.hf_dataset.select(
                        range(min(max_samples, len(self.hf_dataset)))
                    )
        except Exception as e:
            raise RuntimeError(f"Failed to load dataset {dataset_name}: {e}")

        # Pre-build reasoning chains for non-streaming mode so that
        # __getitem__ is a plain list lookup. Samples that fail to parse
        # are dropped here (best-effort).
        if not streaming:
            print("Converting to ReasoningChain format...")
            self.reasoning_chains = []
            for idx, sample in enumerate(self.hf_dataset):
                chain = self._parse_sample_to_chain(sample, idx)
                if chain:
                    self.reasoning_chains.append(chain)
            print(f"Loaded {len(self.reasoning_chains)} reasoning chains")
        else:
            self.reasoning_chains = None

    def __len__(self) -> int:
        """Number of samples; an effectively unbounded sentinel in streaming mode."""
        if self.streaming:
            return self.max_samples or 10**9  # Large number for streaming
        return len(self.reasoning_chains)

    def __getitem__(self, idx: int) -> Optional[ReasoningChain]:
        """
        Get a reasoning chain by index.

        NOTE(review): in streaming mode `idx` is ignored — samples are drawn
        sequentially from the stream and the iterator wraps around on
        exhaustion, so the same `idx` can yield different samples. May return
        None when a streamed sample fails to parse; downstream collation must
        tolerate that.
        """
        if self.streaming:
            # For streaming, get next item from iterator
            try:
                sample = next(self.hf_dataset_iter)
            except StopIteration:
                # Restart the stream from the beginning once exhausted.
                self.hf_dataset_iter = iter(self.hf_dataset)
                sample = next(self.hf_dataset_iter)
            return self._parse_sample_to_chain(sample, idx)
        else:
            return self.reasoning_chains[idx]

    def _parse_sample_to_chain(self, sample: Dict[str, Any], idx: int) -> Optional[ReasoningChain]:
        """
        Parse a FineVision sample into a ReasoningChain.

        Args:
            sample: Raw sample from HuggingFace dataset
            idx: Sample index

        Returns:
            ReasoningChain object or None if parsing fails
        """
        try:
            # Extract fields. Coerce a null text column to "" so downstream
            # string ops (e.g. str.split) never see None.
            image = sample.get(self.image_key)
            text = sample.get(self.text_key, "") or ""

            # Convert image to PIL RGB if needed (raw bytes or a PIL-like
            # object exposing .convert).
            if isinstance(image, bytes):
                image = Image.open(io.BytesIO(image)).convert('RGB')
            elif hasattr(image, 'convert'):
                image = image.convert('RGB')

            # Generate reasoning steps
            if self.generate_synthetic_reasoning:
                steps = self._generate_synthetic_steps(text, self.num_synthetic_steps)
            else:
                # Create single inference step with the text
                steps = [
                    ReasoningStep(
                        step_id=0,
                        step_type=StepType.PERCEPTION,
                        description=text,
                        confidence=0.9,
                    )
                ]

            # Use text as both question and answer for FineVision
            # (it's more like caption/description dataset)
            prompt = "Describe this image in detail."
            answer = text

            # Create reasoning chain
            chain = ReasoningChain(
                chain_id=f"finevision_{idx}",
                image_path=None,  # Streaming mode, no path
                image=image,
                prompt=prompt,
                steps=steps,
                final_answer=answer,
                is_correct=True,  # Assume dataset contains correct descriptions
                metadata={
                    "dataset": self.dataset_name,
                    "split": self.split,
                    "original_idx": idx,
                    "synthetic_reasoning": self.generate_synthetic_reasoning,
                },
            )
            return chain

        except Exception as e:
            # Deliberate best-effort: a malformed sample is skipped with a
            # warning rather than aborting the whole dataset build.
            print(f"Warning: Failed to parse sample {idx}: {e}")
            return None

    def _generate_synthetic_steps(self, text: str, num_steps: int) -> List[ReasoningStep]:
        """
        Generate synthetic reasoning steps from a text description.

        This creates a multi-step reasoning chain by breaking down the
        description into perception, analysis, and conclusion steps.

        Args:
            text: Original text description
            num_steps: Number of steps to generate

        Returns:
            List of ReasoningStep objects (at most num_steps; may be fewer
            when the text is too short to fill the requested middle steps)
        """
        steps = []

        if num_steps <= 1:
            # Single step - just use the text
            steps.append(
                ReasoningStep(
                    step_id=0,
                    step_type=StepType.PERCEPTION,
                    description=text,
                    confidence=0.9,
                )
            )
        elif num_steps == 2:
            # Two steps: perception + inference
            steps.append(
                ReasoningStep(
                    step_id=0,
                    step_type=StepType.PERCEPTION,
                    description="I observe the visual content of the image.",
                    confidence=0.9,
                )
            )
            steps.append(
                ReasoningStep(
                    step_id=1,
                    step_type=StepType.INFERENCE,
                    description=text,
                    confidence=0.85,
                    dependencies=[0],
                )
            )
        else:  # num_steps >= 3
            # Three+ steps: perception + composition + inference
            # Split text into sentence chunks if long enough
            sentences = [s.strip() for s in text.split('.') if s.strip()]

            # Step 1: Perception
            steps.append(
                ReasoningStep(
                    step_id=0,
                    step_type=StepType.PERCEPTION,
                    description="I carefully examine the visual elements in the image.",
                    confidence=0.9,
                )
            )

            # Steps 2..N-1: Composition/Analysis (use text chunks or full text)
            if len(sentences) >= 2 and num_steps > 3:
                # Distribute sentences evenly across the middle steps; the
                # last middle step absorbs any remainder.
                sentences_per_step = max(1, len(sentences) // (num_steps - 2))
                for i in range(num_steps - 2):
                    start_idx = i * sentences_per_step
                    end_idx = start_idx + sentences_per_step if i < num_steps - 3 else len(sentences)
                    chunk_text = '. '.join(sentences[start_idx:end_idx])
                    if chunk_text and not chunk_text.endswith('.'):
                        chunk_text += '.'
                    steps.append(
                        ReasoningStep(
                            step_id=i + 1,
                            step_type=StepType.COMPOSITION,
                            # Fall back to the full text if this chunk came
                            # out empty (e.g. remainder step with no sentences).
                            description=chunk_text if chunk_text else text,
                            confidence=0.85,
                            dependencies=list(range(i + 1)),
                        )
                    )
            else:
                # Not enough sentences, use full text for a single middle step
                steps.append(
                    ReasoningStep(
                        step_id=1,
                        step_type=StepType.COMPOSITION,
                        description=text,
                        confidence=0.85,
                        dependencies=[0],
                    )
                )

            # Final step: Inference/Conclusion (only if room remains)
            if len(steps) < num_steps:
                final_text = "Based on these observations, I can provide a comprehensive description."
                steps.append(
                    ReasoningStep(
                        step_id=len(steps),
                        step_type=StepType.INFERENCE,
                        description=final_text,
                        confidence=0.8,
                        dependencies=list(range(len(steps))),
                    )
                )

        return steps[:num_steps]  # Ensure we don't exceed requested steps


def create_finevision_dataloader(
    tokenizer: Any,
    batch_size: int = 4,
    num_workers: int = 4,
    streaming: bool = True,
    max_samples: Optional[int] = None,
    generate_synthetic_reasoning: bool = True,
    **kwargs,
) -> torch.utils.data.DataLoader:
    """
    Create a DataLoader for FineVision dataset.

    Args:
        tokenizer: Tokenizer for text encoding
        batch_size: Number of samples per batch
        num_workers: Number of worker processes
        streaming: Use streaming mode
        max_samples: Maximum samples to load
        generate_synthetic_reasoning: Generate synthetic reasoning steps
        **kwargs: Additional arguments for FineVisionDataset

    Returns:
        Configured DataLoader
    """
    from .step_data import StepDataset

    # Create FineVision dataset
    finevision_dataset = FineVisionDataset(
        streaming=streaming,
        max_samples=max_samples,
        generate_synthetic_reasoning=generate_synthetic_reasoning,
        **kwargs,
    )

    # If not streaming, wrap in StepDataset for tokenization
    if not streaming and finevision_dataset.reasoning_chains:
        step_dataset = StepDataset(
            reasoning_chains=finevision_dataset.reasoning_chains,
            tokenizer=tokenizer,
            max_steps=10,
            include_visual_features=False,
            filter_incorrect=False,
        )
        return torch.utils.data.DataLoader(
            step_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=True,
            collate_fn=lambda batch: batch,  # Custom collation handled in trainer
        )
    else:
        # For streaming, return raw finevision dataset.
        # NOTE(review): the stateful iterator inside FineVisionDataset is not
        # worker-safe, hence num_workers is forced to 0 when streaming.
        return torch.utils.data.DataLoader(
            finevision_dataset,
            batch_size=batch_size,
            num_workers=0 if streaming else num_workers,
            shuffle=False,
        )


# Example usage
if __name__ == "__main__":
    # Test dataset loading
    print("Testing FineVision dataset adapter...")

    dataset = FineVisionDataset(
        streaming=False,
        max_samples=5,
        generate_synthetic_reasoning=True,
        num_synthetic_steps=3,
    )

    print(f"\nLoaded {len(dataset)} samples")

    # Show first example
    if len(dataset) > 0:
        chain = dataset[0]
        print(f"\n{'='*60}")
        print("Example Reasoning Chain:")
        print(f"{'='*60}")
        print(f"Chain ID: {chain.chain_id}")
        print(f"Prompt: {chain.prompt}")
        print(f"Number of steps: {len(chain.steps)}")
        print(f"\nSteps:")
        for step in chain.steps:
            print(f"  Step {step.step_id} ({step.step_type.value}): {step.description[:100]}...")
        print(f"\nFinal Answer: {chain.final_answer[:200]}...")

        print(f"\n{'='*60}")
        print("\nFormatted with special tokens:")
        print(f"{'='*60}")
        formatted = chain.format_with_tokens()
        # Parenthesized: truncate-and-ellipsize only when over 500 chars.
        print((formatted[:500] + "...") if len(formatted) > 500 else formatted)