import logging
import random
from collections import defaultdict

import torch
import torch.nn as nn
from torch.utils.data import Sampler

logger = logging.getLogger(__name__)

class EarlyExitClassifier(nn.Module):
    """Predicts an early-exit probability from scalar confidence features
    plus a learned embedding of the input modality."""

    def __init__(self, input_dim=5, hidden_dim=64):
        """
        input_dim=5: [Top1_Score, Margin, Entropy, Norm, Variance]
        """
        super().__init__()

        # Two modalities (indices 0/1), each embedded into 4 dimensions.
        self.modality_emb = nn.Embedding(2, 4)

        total_input_dim = input_dim + 4

        # Note: BatchNorm1d requires batches of size > 1 in training mode;
        # HomogeneousBatchSampler below filters out size-1 batches for this.
        self.mlp = nn.Sequential(
            nn.Linear(total_input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, scalar_feats, modality_idx):
        # scalar_feats: (B, input_dim) float; modality_idx: (B,) long.
        mod_feat = self.modality_emb(modality_idx)      # (B, 4)
        x = torch.cat([scalar_feats, mod_feat], dim=1)  # (B, input_dim + 4)
        return self.mlp(x)                              # (B, 1), in (0, 1)
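
# A minimal usage sketch for EarlyExitClassifier. The shapes are implied by
# the layers above but not enforced: scalar_feats is a (B, 5) float tensor of
# the confidence features named in __init__; modality_idx is a (B,) long
# tensor of 0/1 modality ids.
#
#     clf = EarlyExitClassifier()
#     feats = torch.randn(8, 5)
#     modality = torch.randint(0, 2, (8,))
#     p_exit = clf(feats, modality)  # (8, 1) exit probabilities in (0, 1)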


class HomogeneousBatchSampler(Sampler):
    """Batch sampler that yields index lists drawn from a single source
    dataset at a time, so every batch is homogeneous in dataset/modality."""

    def __init__(self, dataset, batch_size, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.groups = defaultdict(list)

        logger.info("Grouping data by dataset source for Homogeneous Sampling...")
        try:
            if hasattr(dataset, 'datasets'):
                # Fast path for ConcatDataset-style containers: assign each
                # sub-dataset's index range in one step instead of loading
                # every item. Assumes all items within a sub-dataset share
                # one source tag, so it is read from the first item only.
                current_idx = 0
                for sub_ds in dataset.datasets:
                    d_name = sub_ds[0].get('global_dataset_name', 'unknown')
                    self.groups[d_name].extend(
                        range(current_idx, current_idx + len(sub_ds))
                    )
                    current_idx += len(sub_ds)
            else:
                # Slow path: fetch every item and read its source tag.
                for idx in range(len(dataset)):
                    item = dataset[idx]
                    d_name = item.get('global_dataset_name', 'unknown')
                    self.groups[d_name].append(idx)

        except Exception as e:
            logger.warning(f"Error grouping dataset: {e}. Falling back to simple index chunking (NOT HOMOGENEOUS).")
            # Discard any partial grouping so the fallback is consistent.
            self.groups = defaultdict(list)
            self.groups['all'] = list(range(len(dataset)))

        logger.info(f"Grouped data into {len(self.groups)} datasets.")

    def __iter__(self):
        batch_list = []
        for d_name, indices in self.groups.items():
            # Reshuffle within each group every epoch.
            random.shuffle(indices)
            for i in range(0, len(indices), self.batch_size):
                batch = indices[i : i + self.batch_size]
                if len(batch) < self.batch_size and self.drop_last:
                    continue
                # Skip size-1 batches: BatchNorm1d cannot normalize a
                # single sample in training mode.
                if len(batch) < 2:
                    continue
                batch_list.append(batch)

        # Shuffle batch order across groups so sources are interleaved.
        random.shuffle(batch_list)
        for batch in batch_list:
            yield batch

    def __len__(self):
        count = 0
        for indices in self.groups.values():
            if self.drop_last:
                count += len(indices) // self.batch_size
            else:
                # Mirror __iter__: the trailing partial batch only counts
                # if it has at least 2 samples.
                remainder = len(indices) % self.batch_size
                full = len(indices) // self.batch_size
                count += full + (1 if remainder >= 2 else 0)
        return count
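

if __name__ == "__main__":
    # Minimal smoke test: a sketch, not part of the training pipeline. It
    # assumes dataset items are dicts carrying the 'global_dataset_name'
    # tag the sampler reads; the source names and sizes below are made up.
    from torch.utils.data import DataLoader, Dataset

    class _ToyDataset(Dataset):
        def __init__(self):
            self.items = (
                [{'global_dataset_name': 'audio', 'x': float(i)} for i in range(10)]
                + [{'global_dataset_name': 'video', 'x': float(i)} for i in range(7)]
            )

        def __len__(self):
            return len(self.items)

        def __getitem__(self, idx):
            return self.items[idx]

    ds = _ToyDataset()
    sampler = HomogeneousBatchSampler(ds, batch_size=4)
    # batch_sampler hands the yielded index lists straight to the dataset;
    # the identity collate_fn keeps each batch as a plain list of dicts.
    loader = DataLoader(ds, batch_sampler=sampler, collate_fn=lambda b: b)

    for batch in loader:
        names = {item['global_dataset_name'] for item in batch}
        assert len(names) == 1, f"mixed batch: {names}"
    print(f"{len(sampler)} homogeneous batches OK")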