import logging

import torch
import torch.nn as nn
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
class EarlyExitClassifier(nn.Module):
    """MLP head that turns per-sample features into a single logit.

    The input to the MLP is the concatenation of:

    * batch-normalised scalar (statistical) features,
    * a learned embedding of the modality index, and
    * optionally, a projected embedding taken from a backbone's hidden
      states (enabled when ``embedding_dim > 0``).

    NOTE(review): the class name suggests the logit drives an early-exit
    decision; confirm against the caller.
    """

    def __init__(
        self,
        input_dim: int = 27,
        hidden_dim: int = 128,
        embedding_dim: int = 0,
        *,
        num_modalities: int = 2,
        modality_dim: int = 4,
        emb_proj_dim: int = 64,
        dropout: float = 0.2,
    ):
        """Build the classifier.

        Args:
            input_dim: Dimension of the scalar (statistical) feature vector.
            hidden_dim: Width of the hidden layer of the MLP.
            embedding_dim: If > 0, the model additionally takes the
                backbone's hidden states (of this size) as input. ``0`` or
                ``None`` disables that branch.
            num_modalities: Number of distinct modality indices.
            modality_dim: Width of the learned modality embedding.
            emb_proj_dim: Width the backbone embedding is projected to
                before concatenation.
            dropout: Dropout probability used inside the MLP.
        """
        super().__init__()

        # Normalise the raw scalar features before mixing them with the
        # learned embeddings.
        self.scalar_bn = nn.BatchNorm1d(input_dim)

        # One learned vector per modality index.
        self.modality_emb = nn.Embedding(num_modalities, modality_dim)

        # Treat None the same as 0 so "no embedding" can be expressed either way.
        self.use_embedding = embedding_dim is not None and embedding_dim > 0

        total_input_dim = input_dim + modality_dim
        if self.use_embedding:
            # Compress the backbone embedding to a small fixed width.
            self.emb_proj = nn.Sequential(
                nn.Linear(embedding_dim, emb_proj_dim),
                nn.LayerNorm(emb_proj_dim),
                nn.ReLU(),
            )
            total_input_dim += emb_proj_dim

        self.mlp = nn.Sequential(
            nn.Linear(total_input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, scalar_feats, modality_idx, qry_emb=None):
        """Compute one logit per sample.

        Args:
            scalar_feats: Tensor of shape ``[B, input_dim]``.
            modality_idx: Integer tensor of shape ``[B]`` holding modality
                indices in ``[0, num_modalities)``.
            qry_emb: Tensor of shape ``[B, embedding_dim]``. Required when
                the model was built with ``embedding_dim > 0``; ignored
                otherwise.

        Returns:
            Tensor of shape ``[B, 1]`` containing raw (pre-sigmoid) logits.

        Raises:
            ValueError: If the model uses the embedding branch but
                ``qry_emb`` is ``None``.
        """
        # Fail before running any layer so a missing input is reported clearly.
        if self.use_embedding and qry_emb is None:
            raise ValueError("Model initialized with embedding_dim > 0 but qry_emb is None")

        features = [
            self.scalar_bn(scalar_feats),
            self.modality_emb(modality_idx),
        ]
        if self.use_embedding:
            features.append(self.emb_proj(qry_emb))

        x = torch.cat(features, dim=1)
        return self.mlp(x)