Spaces:
Build error
Build error
| """Data preparation utilities for training.""" | |
| import logging | |
| from typing import Dict, List, Optional, Any | |
| from torch.utils.data import Dataset | |
| from transformers import PreTrainedTokenizer | |
| from src.config import DataConfig | |
| logger = logging.getLogger(__name__) | |
class SummarizationDataset(Dataset):
    """PyTorch Dataset for causal-LM summarization training.

    Each item concatenates the prompt-formatted input, the target summary,
    and the tokenizer's EOS token, then tokenizes the result to a fixed
    length (``max_input_length + max_target_length``).
    """

    def __init__(
        self,
        data: List[Dict[str, str]],
        tokenizer: PreTrainedTokenizer,
        config: Optional[DataConfig] = None,
    ):
        """
        Args:
            data: Records containing the columns named by
                ``config.input_column`` and ``config.target_column``.
            tokenizer: Tokenizer used to encode the concatenated text.
            config: Data configuration; a default ``DataConfig`` is used
                when omitted.
        """
        self.data = data
        self.tokenizer = tokenizer
        self.config = config or DataConfig()

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Tokenize one example and return input_ids/attention_mask/labels."""
        item = self.data[idx]
        prompt = self.config.format_prompt(item[self.config.input_column])
        full_text = prompt + item[self.config.target_column] + self.tokenizer.eos_token
        encoding = self.tokenizer(
            full_text,
            truncation=True,
            max_length=self.config.max_input_length + self.config.max_target_length,
            padding="max_length",
            return_tensors="pt",
        )
        # squeeze(0) drops only the batch dim added by return_tensors="pt";
        # a bare squeeze() could also collapse other size-1 dims.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        # Fix: previously labels were the raw input_ids, so the loss was
        # also computed over padding. Mask padded positions with -100, the
        # ignore index of PyTorch/HF cross-entropy.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
def create_dummy_data(num_samples: int = 10) -> List[Dict[str, str]]:
    """Generate dummy data for testing the training pipeline.

    Args:
        num_samples: Number of synthetic document/summary pairs to produce.

    Returns:
        A list of dicts with ``"text"`` and ``"summary"`` keys.
    """
    samples = [
        {
            "text": (
                f"This is sample document {i}. It contains information about topic {i % 3}. "
                f"The document discusses various aspects and provides detailed analysis. "
                f"Key points include methodology, results, and conclusions for study {i}."
            ),
            "summary": f"Document {i} analyzes topic {i % 3}, covering methodology, results, and conclusions.",
        }
        for i in range(num_samples)
    ]
    logging.getLogger(__name__).info(f"Created {num_samples} dummy samples")
    return samples
def prepare_dataset(
    data: List[Dict[str, str]],
    tokenizer: PreTrainedTokenizer,
    config: Optional[DataConfig] = None,
) -> SummarizationDataset:
    """Wrap raw records in a :class:`SummarizationDataset` for training.

    Args:
        data: Records containing the input and target text columns.
        tokenizer: Tokenizer used to encode each example.
        config: Optional data configuration; when ``None`` the dataset
            falls back to a default ``DataConfig``.

    Returns:
        A ``SummarizationDataset`` over ``data``.
    """
    return SummarizationDataset(data=data, tokenizer=tokenizer, config=config)