import logging
import random
from collections.abc import Iterator

import numpy as np
import pandas as pd
import torch

from src.data.batch_composer import BatchComposer, ComposedDataset
from src.data.containers import BatchTimeSeriesContainer
from src.data.frequency import parse_frequency
from src.gift_eval.constants import ALL_DATASETS
from src.gift_eval.data import Dataset as GiftEvalDataset

logger = logging.getLogger(__name__)


class GiftEvalDataLoader:
    """
    Data loader for GIFT-eval datasets, converting them to the BatchTimeSeriesContainer format.
    Supports both training and validation modes.
    """

    TERMS = ["short", "medium", "long"]

    def __init__(
        self,
        mode: str = "train",
        batch_size: int = 32,
        device: torch.device | None = None,
        shuffle: bool = True,
        to_univariate: bool = False,
        max_context_length: int | None = None,
        max_windows: int = 20,
        skip_datasets_with_nans: bool = False,
        datasets_to_use: list[str] | None = None,
        dataset_storage_path: str | None = None,
    ):
        """
        Initialize the GIFT-eval data loader.

        Args:
            mode: Either "train" or "validation".
            batch_size: Number of samples per batch.
            device: Device to load data to.
            shuffle: Whether to shuffle data.
            to_univariate: Whether to convert multivariate data to multiple univariate series.
            max_context_length: Optional maximum total window length (context + forecast) to prevent memory issues.
            max_windows: Number of windows to use for training/validation.
            skip_datasets_with_nans: Whether to skip series (entries) that contain NaN values.
            datasets_to_use: Optional list of dataset names to use. If None, uses all available datasets.
            dataset_storage_path: Path on disk where GIFT-eval HuggingFace datasets are stored.
        """
        if datasets_to_use is not None and len(datasets_to_use) > 0:
            invalid_datasets = [ds for ds in datasets_to_use if ds not in ALL_DATASETS]
            if invalid_datasets:
                logger.warning(f"Invalid datasets requested: {invalid_datasets}")
                logger.warning(f"Available datasets: {ALL_DATASETS}")
                self.dataset_names = [ds for ds in datasets_to_use if ds in ALL_DATASETS]
            else:
                self.dataset_names = datasets_to_use
            logger.info(f"Using subset of datasets: {len(self.dataset_names)}/{len(ALL_DATASETS)} datasets")
            logger.info(f"Selected datasets: {self.dataset_names}")
        else:
            self.dataset_names = ALL_DATASETS
            logger.info(f"Using all available datasets: {len(self.dataset_names)} datasets")

        self.terms = self.TERMS
        self.mode = mode
        self.batch_size = batch_size
        self.device = device
        self.shuffle = shuffle
        self.to_univariate = to_univariate
        self.max_context_length = max_context_length
        self.skip_datasets_with_nans = skip_datasets_with_nans
        self.max_windows = max_windows
        self.dataset_storage_path = dataset_storage_path

        # Load the datasets and pre-build all batches for the first epoch.
        self._load_datasets()

        self._current_idx = 0
        self._epoch_data = []
        self._prepare_epoch_data()

    def _load_datasets(self) -> None:
        """Load all specified GIFT-eval datasets."""
        self.datasets = {}
        self.dataset_prediction_lengths = {}

        for dataset_name in self.dataset_names:
            # Use a single window for M4-family datasets.
            if dataset_name.startswith("m4_"):
                max_windows = 1
            else:
                max_windows = self.max_windows
            try:
                # Load once to inspect the target dimensionality.
                temp_dataset = GiftEvalDataset(
                    name=dataset_name,
                    term=self.terms[0],
                    to_univariate=False,
                    max_windows=max_windows,
                    storage_path=self.dataset_storage_path,
                )

                # Only convert to univariate if the dataset is actually multivariate.
                to_univariate = self.to_univariate and temp_dataset.target_dim > 1

                for term in self.terms:
                    dataset_key = f"{dataset_name}_{term}"
                    dataset = GiftEvalDataset(
                        name=dataset_name,
                        term=term,
                        to_univariate=to_univariate,
                        max_windows=max_windows,
                        storage_path=self.dataset_storage_path,
                    )

                    self.datasets[dataset_key] = dataset
                    self.dataset_prediction_lengths[dataset_key] = dataset.prediction_length

                    logger.info(
                        f"Loaded {dataset_key} - prediction_length: {dataset.prediction_length}, "
                        f"frequency: {dataset.freq}, target_dim: {dataset.target_dim}, "
                        f"min_length: {dataset._min_series_length}, windows: {dataset.windows}"
                    )

            except Exception as e:
                logger.warning(f"Failed to load dataset {dataset_name}: {str(e)}")
                continue

    def _contains_nan(self, data_entry: dict) -> bool:
        """Check whether a data entry contains NaN values."""
        target = data_entry.get("target")
        if target is None:
            return False

        try:
            target_np = np.asarray(target, dtype=np.float32)
            return bool(np.isnan(target_np).any())
        except Exception:
            # Treat entries that cannot be coerced to float32 as containing NaNs so they are skipped.
            logger.warning("NaN check: failed to coerce target to float32; skipping entry")
            return True

    def _convert_to_container(
        self, data_entries: list[dict], prediction_length: int, dataset_freq: str
    ) -> BatchTimeSeriesContainer:
        """Convert a batch of data entries to BatchTimeSeriesContainer format with a fixed future length."""
        batch_size = len(data_entries)
        max_history_len = 0

        # First pass: determine the longest history across the batch.
        for entry in data_entries:
            target = np.asarray(entry["target"], dtype=np.float32)
            if target.ndim == 1:
                target = target.reshape(1, -1)

            _, seq_len = target.shape

            # Cap the total window length (history + future) if a maximum is configured.
            effective_max_context = self.max_context_length if self.max_context_length is not None else seq_len
            if seq_len > effective_max_context:
                seq_len = effective_max_context

            history_len = max(0, seq_len - prediction_length)
            max_history_len = max(max_history_len, history_len)

        # All entries in a batch are assumed to share the same number of channels.
        first_target = np.asarray(data_entries[0]["target"], dtype=np.float32)
        if first_target.ndim == 1:
            first_target = first_target.reshape(1, -1)
        num_channels = first_target.shape[0]

        # Allocate padded arrays; positions without data stay NaN and are masked out.
        history_values = np.full((batch_size, max_history_len, num_channels), np.nan, dtype=np.float32)
        future_values = np.full((batch_size, prediction_length, num_channels), np.nan, dtype=np.float32)
        history_mask = np.zeros((batch_size, max_history_len), dtype=bool)

        # Second pass: fill the history/future slices for each entry.
        for i, entry in enumerate(data_entries):
            target = np.asarray(entry["target"], dtype=np.float32)
            if target.ndim == 1:
                target = target.reshape(1, -1)

            full_seq_len = target.shape[1]
            total_len_allowed = self.max_context_length if self.max_context_length is not None else full_seq_len
            total_len_for_entry = min(full_seq_len, total_len_allowed)

            if total_len_for_entry < prediction_length + 1:
                raise ValueError("Entry too short after max_context_length truncation to form history+future window")

            # Keep the most recent steps and split them into history and future.
            truncated = target[:, -total_len_for_entry:]
            cur_history_len = total_len_for_entry - prediction_length

            hist = truncated[:, :cur_history_len]
            fut = truncated[:, cur_history_len : cur_history_len + prediction_length]

            # Targets are stored as (channels, time); transpose to (time, channels).
            history_values[i, :cur_history_len, :] = hist.T
            future_values[i, :, :] = fut.T
            history_mask[i, :cur_history_len] = True

        # Use the first entry's start timestamp for the whole batch.
        start_timestamp = data_entries[0]["start"]
        if hasattr(start_timestamp, "to_timestamp"):
            start_numpy = start_timestamp.to_timestamp().to_numpy()
        else:
            start_numpy = pd.Timestamp(start_timestamp).to_numpy()
        start_list = [start_numpy for _ in range(batch_size)]

        frequency_enum = parse_frequency(dataset_freq)
        frequency_list = [frequency_enum for _ in range(batch_size)]

        return BatchTimeSeriesContainer(
            history_values=torch.tensor(history_values, dtype=torch.float32),
            future_values=torch.tensor(future_values, dtype=torch.float32),
            start=start_list,
            frequency=frequency_list,
            history_mask=torch.tensor(history_mask, dtype=torch.bool) if self.mode == "train" else None,
        )
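
    # Illustrative example (not from the original code, derived from the method above): with
    # two univariate entries of lengths 100 and 80 and prediction_length=24, this method
    # returns history_values of shape (2, 76, 1) (the shorter series is NaN-padded at the
    # end and masked out), future_values of shape (2, 24, 1), and, in "train" mode, a
    # (2, 76) boolean history_mask.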

    def _prepare_epoch_data(self) -> None:
        """Prepare all batches for one epoch."""
        self._epoch_data = []

        for dataset_key, dataset in self.datasets.items():
            try:
                if self.mode == "train":
                    data = dataset.training_dataset
                else:
                    data = dataset.validation_dataset

                valid_entries = []
                dataset_freq = dataset.freq
                prediction_length = self.dataset_prediction_lengths[dataset_key]

                for entry in data:
                    # Optionally skip entries that contain NaNs.
                    if self.skip_datasets_with_nans and self._contains_nan(entry):
                        continue

                    target = np.asarray(entry["target"])
                    if target.ndim == 1:
                        seq_len = len(target)
                    else:
                        seq_len = target.shape[1]

                    # Training entries must be long enough for at least one history step plus the forecast.
                    if self.mode == "train" and seq_len < prediction_length + 1:
                        continue

                    valid_entries.append(entry)

                if not valid_entries:
                    logger.warning(f"No valid entries found for {dataset_key}")
                    continue

                # Group the valid entries into batches.
                for i in range(0, len(valid_entries), self.batch_size):
                    batch_entries = valid_entries[i : i + self.batch_size]
                    try:
                        batch_container = self._convert_to_container(batch_entries, prediction_length, dataset_freq)
                        self._epoch_data.append((dataset_key, batch_container))
                    except Exception as e:
                        logger.warning(f"Failed to create batch for {dataset_key}: {str(e)}")
                        continue

            except Exception as e:
                logger.warning(
                    f"Failed to process dataset {dataset_key}: {str(e)}. "
                    f"Dataset may be too short for the required offset."
                )
                continue

        if self.mode == "train" and self.shuffle:
            random.shuffle(self._epoch_data)

        logger.info(f"Prepared {len(self._epoch_data)} batches for {self.mode} mode")

    def __iter__(self) -> Iterator[BatchTimeSeriesContainer]:
        """Iterate through the batches of one epoch."""
        self._current_idx = 0

        # Reshuffle at the start of each training epoch.
        if self.mode == "train" and self.shuffle:
            random.shuffle(self._epoch_data)

        return self

    def __next__(self) -> BatchTimeSeriesContainer:
        """Get the next batch."""
        if not self._epoch_data:
            raise StopIteration("No valid data available")

        if self._current_idx >= len(self._epoch_data):
            raise StopIteration

        dataset_key, batch = self._epoch_data[self._current_idx]
        self._current_idx += 1

        if self.device is not None:
            batch.to_device(self.device)

        return batch

    def __len__(self) -> int:
        """Return the number of batches per epoch."""
        return len(self._epoch_data)
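
# Illustrative usage sketch (assumption, not part of the original module): the dataset name
# and storage path below are placeholders.
#
#     loader = GiftEvalDataLoader(
#         mode="train",
#         batch_size=32,
#         datasets_to_use=["m4_hourly"],
#         dataset_storage_path="/path/to/gift_eval/data",
#     )
#     for batch in loader:  # each batch is a BatchTimeSeriesContainer
#         ...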


class CyclicGiftEvalDataLoader:
    """
    Wrapper around GiftEvalDataLoader that provides cycling behavior for training.
    This allows training for a fixed number of iterations per epoch, cycling through
    the available data as needed.
    """

    def __init__(self, base_loader: GiftEvalDataLoader, num_iterations_per_epoch: int):
        """
        Initialize the cyclic data loader.

        Args:
            base_loader: The underlying GiftEvalDataLoader.
            num_iterations_per_epoch: Number of iterations to run per epoch.
        """
        self.base_loader = base_loader
        self.num_iterations_per_epoch = num_iterations_per_epoch
        self.dataset_names = base_loader.dataset_names
        self.device = base_loader.device

    def __iter__(self) -> Iterator[BatchTimeSeriesContainer]:
        """Iterate for exactly num_iterations_per_epoch iterations."""
        self._current_iteration = 0
        self._base_iter = iter(self.base_loader)
        return self

    def __next__(self) -> BatchTimeSeriesContainer:
        """Get the next batch, cycling through the base loader as needed."""
        if self._current_iteration >= self.num_iterations_per_epoch:
            raise StopIteration

        try:
            batch = next(self._base_iter)
        except StopIteration:
            # The base loader is exhausted; restart it and keep going.
            self._base_iter = iter(self.base_loader)
            batch = next(self._base_iter)

        self._current_iteration += 1
        return batch

    def __len__(self) -> int:
        """Return the configured number of iterations per epoch."""
        return self.num_iterations_per_epoch
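
# Illustrative usage sketch (assumption, not part of the original module): wrap a base
# loader so every "epoch" runs a fixed number of iterations regardless of dataset size.
#
#     base_loader = GiftEvalDataLoader(mode="train", batch_size=32)
#     train_loader = CyclicGiftEvalDataLoader(base_loader, num_iterations_per_epoch=500)
#     for batch in train_loader:  # yields exactly 500 batches, recycling as needed
#         ...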


def create_synthetic_dataloader(
    base_data_dir: str,
    batch_size: int = 128,
    num_batches_per_epoch: int = 1000,
    generator_proportions: dict[str, float] | None = None,
    mixed_batches: bool = True,
    augmentations: dict[str, bool] | None = None,
    augmentation_probabilities: dict[str, float] | None = None,
    device: torch.device | None = None,
    num_workers: int = 0,
    pin_memory: bool = True,
    global_seed: int = 42,
    nan_stats_path: str | None = None,
    nan_patterns_path: str | None = None,
    chosen_scaler_name: str | None = None,
) -> torch.utils.data.DataLoader:
    """
    Create a PyTorch DataLoader for training with saved generator batches.

    Args:
        base_data_dir: Base directory containing generator subdirectories.
        batch_size: Size of each training batch.
        num_batches_per_epoch: Number of batches per epoch.
        generator_proportions: Dict mapping generator names to proportions.
        mixed_batches: Whether to create mixed or uniform batches.
        augmentations: Dict mapping augmentation names to booleans.
        augmentation_probabilities: Dict mapping augmentation names to probabilities.
        device: Target device.
        num_workers: Number of DataLoader workers.
        pin_memory: Whether to pin memory.
        global_seed: Global random seed.
        nan_stats_path: Path to the NaN statistics file.
        nan_patterns_path: Path to the NaN patterns file.
        chosen_scaler_name: Name of the scaler used in training.

    Returns:
        A PyTorch DataLoader.
    """
    composer = BatchComposer(
        base_data_dir=base_data_dir,
        generator_proportions=generator_proportions,
        mixed_batches=mixed_batches,
        device=device,
        augmentations=augmentations,
        augmentation_probabilities=augmentation_probabilities,
        global_seed=global_seed,
        nan_stats_path=nan_stats_path,
        nan_patterns_path=nan_patterns_path,
        chosen_scaler_name=chosen_scaler_name,
    )

    dataset = ComposedDataset(
        batch_composer=composer,
        num_batches_per_epoch=num_batches_per_epoch,
        batch_size=batch_size,
    )

    def collate_fn(batch):
        """Custom collate function that returns a single BatchTimeSeriesContainer."""
        # ComposedDataset already yields full batches, so each DataLoader "batch" has exactly one element.
        return batch[0]

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=collate_fn,
        drop_last=False,
    )

    logger.info(
        f"Created DataLoader with {len(dataset)} batches per epoch, "
        f"batch_size={batch_size}, mixed_batches={mixed_batches}"
    )

    return dataloader
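
# Illustrative usage sketch (assumption, not part of the original module): the directory
# below is a placeholder.
#
#     train_loader = create_synthetic_dataloader(
#         base_data_dir="/path/to/generator_batches",
#         batch_size=128,
#         num_batches_per_epoch=1000,
#     )
#     for batch in train_loader:  # each batch is a BatchTimeSeriesContainer
#         ...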


class SyntheticValidationDataset(torch.utils.data.Dataset):
    """
    Fixed synthetic validation dataset that generates a small number of batches
    using the same composition approach as the training data.
    """

    def __init__(
        self,
        base_data_dir: str,
        batch_size: int = 128,
        num_batches: int = 2,
        future_length: int = 512,
        generator_proportions: dict[str, float] | None = None,
        augmentations: dict[str, bool] | None = None,
        augmentation_probabilities: dict[str, float] | None = None,
        device: torch.device | None = None,
        global_seed: int = 42,
        chosen_scaler_name: str | None = None,
        nan_stats_path: str | None = None,
        nan_patterns_path: str | None = None,
        rank: int = 0,
        world_size: int = 1,
    ):
        """
        Initialize the validation dataset.

        Args:
            base_data_dir: Base directory containing generator subdirectories.
            batch_size: Size of each validation batch.
            num_batches: Number of validation batches to generate (typically 1 or 2).
            future_length: Length of the forecast horizon in each generated batch.
            generator_proportions: Dict mapping generator names to proportions.
            augmentations: Dict mapping augmentation names to booleans.
            augmentation_probabilities: Dict mapping augmentation names to probabilities.
            device: Device to load tensors to.
            global_seed: Global random seed.
            chosen_scaler_name: Name of the scaler used in training.
            nan_stats_path: Path to the NaN statistics file.
            nan_patterns_path: Path to the NaN patterns file.
            rank: Rank of the current process.
            world_size: Total number of processes.
        """
        self.batch_size = batch_size
        self.num_batches = num_batches
        self.device = device

        # Disable the length_shortening augmentation for validation batches.
        val_augmentations = dict(augmentations or {})
        val_augmentations["length_shortening"] = False

        self.batch_composer = BatchComposer(
            base_data_dir=base_data_dir,
            generator_proportions=generator_proportions,
            mixed_batches=True,
            device=device,
            global_seed=global_seed + 999999,
            augmentations=val_augmentations,
            augmentation_probabilities=augmentation_probabilities,
            nan_stats_path=nan_stats_path,
            nan_patterns_path=nan_patterns_path,
            chosen_scaler_name=chosen_scaler_name,
            rank=rank,
            world_size=world_size,
        )

        # Pre-generate a fixed set of validation batches with deterministic seeds.
        self.validation_batches = []
        for i in range(num_batches):
            batch, _ = self.batch_composer.create_batch(
                batch_size=batch_size,
                future_length=future_length,
                seed=global_seed + 999999 + i,
            )
            self.validation_batches.append(batch)

        logger.info(f"Created {num_batches} fixed validation batches with batch_size={batch_size}")

    def __len__(self) -> int:
        return self.num_batches

    def __getitem__(self, idx: int) -> BatchTimeSeriesContainer:
        """
        Get a pre-generated validation batch by index.

        Args:
            idx: Batch index.

        Returns:
            BatchTimeSeriesContainer
        """
        if idx >= len(self.validation_batches):
            raise IndexError(f"Batch index {idx} out of range")

        batch = self.validation_batches[idx]

        if self.device is not None:
            batch.to_device(self.device)

        return batch
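
# Illustrative usage sketch (assumption, not part of the original module):
#
#     val_dataset = SyntheticValidationDataset(
#         base_data_dir="/path/to/generator_batches",
#         batch_size=128,
#         num_batches=2,
#     )
#     for i in range(len(val_dataset)):
#         val_batch = val_dataset[i]  # fixed, pre-generated BatchTimeSeriesContainer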


def create_synthetic_dataset(
    base_data_dir: str,
    batch_size: int = 128,
    num_batches_per_epoch: int = 1000,
    generator_proportions: dict[str, float] | None = None,
    mixed_batches: bool = True,
    augmentations: dict[str, bool] | None = None,
    augmentation_probabilities: dict[str, float] | None = None,
    global_seed: int = 42,
    nan_stats_path: str | None = None,
    nan_patterns_path: str | None = None,
    chosen_scaler_name: str | None = None,
    rank: int = 0,
    world_size: int = 1,
) -> ComposedDataset:
    """
    Create the ComposedDataset for training with saved generator batches.

    Args:
        base_data_dir: Base directory containing generator subdirectories.
        batch_size: Size of each training batch.
        num_batches_per_epoch: Number of batches per epoch.
        generator_proportions: Dict mapping generator names to proportions.
        mixed_batches: Whether to create mixed or uniform batches.
        augmentations: Dict mapping augmentation names to booleans.
        augmentation_probabilities: Dict mapping augmentation names to probabilities.
        global_seed: Global random seed.
        nan_stats_path: Path to the NaN statistics file.
        nan_patterns_path: Path to the NaN patterns file.
        chosen_scaler_name: Name of the scaler to use.
        rank: Rank of the current process.
        world_size: Total number of processes.

    Returns:
        A ComposedDataset instance.
    """
    composer = BatchComposer(
        base_data_dir=base_data_dir,
        generator_proportions=generator_proportions,
        mixed_batches=mixed_batches,
        device=None,
        augmentations=augmentations,
        augmentation_probabilities=augmentation_probabilities,
        global_seed=global_seed,
        nan_stats_path=nan_stats_path,
        nan_patterns_path=nan_patterns_path,
        chosen_scaler_name=chosen_scaler_name,
        rank=rank,
        world_size=world_size,
    )

    dataset = ComposedDataset(
        batch_composer=composer,
        num_batches_per_epoch=num_batches_per_epoch,
        batch_size=batch_size,
    )

    logger.info(
        f"Created ComposedDataset with {len(dataset)} batches per epoch, "
        f"batch_size={batch_size}, mixed_batches={mixed_batches}"
    )

    return dataset
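
# Illustrative usage sketch (assumption, not part of the original module): wrap the
# ComposedDataset in a DataLoader with a pass-through collate function, mirroring
# create_synthetic_dataloader above.
#
#     dataset = create_synthetic_dataset(base_data_dir="/path/to/generator_batches", rank=0, world_size=1)
#     loader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=lambda b: b[0])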