from typing import Optional

import torch
import torch.nn as nn
from transformers import Wav2Vec2Model


class ForensicDeepfakeDetector(nn.Module):
    """Binary audio deepfake classifier on top of a frozen wav2vec 2.0 encoder.

    A pretrained ``Wav2Vec2Model`` produces frame-level features, which are
    averaged over time and passed to a small MLP head. The head returns one
    raw logit per utterance; no sigmoid is applied, so pair it with a
    logit-based loss such as ``nn.BCEWithLogitsLoss``.

    Only the classifier head is trainable. The backbone has
    ``requires_grad=False`` and is held in eval mode even while the wrapper
    is in training mode (see ``train``).

    Args:
        model_name: Hugging Face checkpoint name or local path passed to
            ``Wav2Vec2Model.from_pretrained``.
    """

    def __init__(self, model_name="facebook/wav2vec2-xls-r-300m"):
        super().__init__()

        # 1. Multilingual backbone (Arabic/English).
        self.backbone = Wav2Vec2Model.from_pretrained(model_name)

        # Freeze to fit 8GB VRAM: no gradients are kept for the encoder.
        self.backbone.requires_grad_(False)
        # Freezing gradients does not disable train-time stochastic layers,
        # so the encoder is also put (and kept, see `train`) in eval mode.
        self.backbone.eval()

        # 2. Adaptive pooling collapses the time axis whatever its length.
        self.pooling = nn.AdaptiveAvgPool1d(1)

        # 3. Classifier (outputs raw logits). The input width follows the
        # checkpoint's hidden size instead of a hard-coded 1024, so other
        # wav2vec 2.0 variants work too (the default checkpoint is 1024).
        hidden_size = self.backbone.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            # No sigmoid here: numerical stability is left to the loss.
            nn.Linear(512, 1),
        )

    def train(self, mode: bool = True) -> "ForensicDeepfakeDetector":
        """Set the training mode of the head while keeping the backbone in eval.

        ``nn.Module.train`` recurses into every child, which would switch the
        frozen encoder back to training mode and re-enable its stochastic
        train-time behaviour. The encoder is a fixed feature extractor here,
        so it is forced back to eval after the recursive call.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute one deepfake logit per input waveform.

        Args:
            x: Batch of input values forwarded to the backbone; presumably
                raw waveforms shaped ``[B, samples]`` -- confirm with caller.
            attention_mask: Optional ``[B, samples]`` mask with 1 for real
                samples and 0 for padding. When given, it is forwarded to the
                backbone and padded frames are excluded from the time
                average. When omitted, every frame is averaged.

        Returns:
            Raw logits of shape ``[B, 1]``.
        """
        # The backbone is frozen, so skip building its autograd graph.
        with torch.no_grad():
            hidden = self.backbone(
                x, attention_mask=attention_mask
            ).last_hidden_state  # [B, Time, H]

            if attention_mask is None:
                # AdaptiveAvgPool1d pools the last axis, so put time last.
                pooled = self.pooling(hidden.permute(0, 2, 1)).squeeze(-1)  # [B, H]
            else:
                # Downsample the sample-level mask to frame resolution, then
                # average only over frames that come from real audio.
                frame_mask = self.backbone._get_feature_vector_attention_mask(
                    hidden.shape[1], attention_mask
                )
                weights = frame_mask.unsqueeze(-1).to(hidden.dtype)  # [B, Time, 1]
                # clamp guards against division by zero on an all-padding row.
                pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)

        return self.classifier(pooled)