| """ |
| AIFinder Neural Network |
| Single-headed MLP: predicts provider only. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
class AIFinderNet(nn.Module):
    """MLP classifier with a single output head over providers.

    A shared backbone of two ``Linear -> BatchNorm1d -> ReLU -> Dropout``
    stages maps the input features to an embedding; one linear head then
    turns that embedding into per-provider logits.
    """

    def __init__(
        self,
        input_dim,
        num_providers,
        hidden_dim=1024,
        embed_dim=256,
        dropout=0.3,
    ):
        """Build the backbone and the provider head.

        Args:
            input_dim: Number of input features fed to the first linear layer.
            num_providers: Number of logits emitted by the head.
            hidden_dim: Output width of the first backbone stage.
            embed_dim: Output width of the second backbone stage, which is
                also the input width of the head.
            dropout: Drop probability applied after each backbone stage.
        """
        super().__init__()

        # Stages are instantiated in order so that the layer indices inside
        # the Sequential (0..7), and hence the state_dict keys, stay stable.
        stage_widths = ((input_dim, hidden_dim), (hidden_dim, embed_dim))
        layers = []
        for fan_in, fan_out in stage_widths:
            layers += [
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
        self.backbone = nn.Sequential(*layers)
        self.provider_head = nn.Linear(embed_dim, num_providers)

    def forward(self, x):
        """Return the raw provider logits (no softmax applied) for ``x``.

        NOTE(review): ``x`` is presumably shaped (batch, input_dim) -- confirm
        against the caller.
        """
        return self.provider_head(self.backbone(x))
|
|