# Deepfake-audio-detection/model_utils.py
import torch
import torch.nn as nn
from transformers import Wav2Vec2Model


class ForensicDeepfakeDetector(nn.Module):
    def __init__(self, model_name="facebook/wav2vec2-xls-r-300m"):
        super().__init__()
        # 1. Multilingual Backbone (Arabic/English)
        self.backbone = Wav2Vec2Model.from_pretrained(model_name)
        # Freeze to fit 8GB VRAM
        for param in self.backbone.parameters():
            param.requires_grad = False
        # 2. Adaptive Pooling (Handles variable lengths)
        self.pooling = nn.AdaptiveAvgPool1d(1)
        # 3. Classifier (Outputs raw Logits)
        self.classifier = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 1)  # No Sigmoid here! Stability is handled by the Loss function
        )

    def forward(self, x):
        with torch.no_grad():
            outputs = self.backbone(x).last_hidden_state
        outputs = outputs.permute(0, 2, 1)          # [B, 1024, Time]
        pooled = self.pooling(outputs).squeeze(-1)  # [B, 1024]
        return self.classifier(pooled)
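

# --- Usage sketch (not part of the original file) ---------------------------
# A minimal, hypothetical smoke test showing how the detector could be driven.
# Assumptions not stated in the original: 16 kHz mono waveforms as raw float
# tensors, label convention 1 = fake / 0 = real, and nn.BCEWithLogitsLoss as
# the training criterion (the standard numerically stable pairing for a head
# that emits raw logits without a Sigmoid, as above).
if __name__ == "__main__":
    model = ForensicDeepfakeDetector()
    model.train()  # only the classifier head has trainable parameters

    # Dummy batch: 2 clips of 3 seconds at an assumed 16 kHz sample rate
    waveforms = torch.randn(2, 3 * 16000)
    labels = torch.tensor([[1.0], [0.0]])  # assumed: 1 = fake, 0 = real

    logits = model(waveforms)           # [B, 1] raw logits
    criterion = nn.BCEWithLogitsLoss()  # applies the Sigmoid internally, stably
    loss = criterion(logits, labels)
    loss.backward()

    probs = torch.sigmoid(logits)       # probabilities for inference/reporting
    print(loss.item(), probs.squeeze(-1).tolist())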