|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import Wav2Vec2Model |
|
|
|
|
|
class ForensicDeepfakeDetector(nn.Module):
    """Binary deepfake classifier on top of a frozen wav2vec 2.0 backbone.

    The pretrained backbone is loaded from ``model_name`` and all of its
    parameters are frozen, so only the small MLP head (``classifier``) is
    trainable. The backbone's last hidden state is mean-pooled over time and
    mapped to a single logit per input.

    Args:
        model_name: Model identifier or local path passed to
            ``Wav2Vec2Model.from_pretrained``.
    """

    def __init__(self, model_name="facebook/wav2vec2-xls-r-300m"):
        super().__init__()

        self.backbone = Wav2Vec2Model.from_pretrained(model_name)

        # Freeze the backbone: only the classification head gets gradients.
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Averages over the last axis; forward() moves time there first.
        self.pooling = nn.AdaptiveAvgPool1d(1)

        # Size the head from the loaded backbone instead of hard-coding 1024,
        # so checkpoints with a different hidden size also work. For the
        # default checkpoint this is the same 1024 the head previously assumed.
        hidden_size = self.backbone.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 1),
        )

    def forward(self, x):
        """Return one raw (pre-sigmoid) logit per input.

        Args:
            x: Batch passed straight to the backbone; presumably raw
                waveforms of shape (batch, num_samples) -- confirm against the
                caller / feature extractor.

        Returns:
            Tensor of shape (batch, 1) with unnormalized scores.
        """
        # The backbone is frozen, so skip building its autograd graph.
        # NOTE(review): no_grad does not switch the backbone to eval mode; when
        # this module is in training mode the backbone's own dropout (and any
        # time masking its config enables) still runs -- confirm intended.
        with torch.no_grad():
            hidden = self.backbone(x).last_hidden_state  # (batch, frames, hidden)

        # AdaptiveAvgPool1d pools the last axis, so put time last.
        hidden = hidden.permute(0, 2, 1)  # (batch, hidden, frames)
        # NOTE(review): no attention mask is passed, so padded frames (if any)
        # are included in this mean.
        pooled = self.pooling(hidden).squeeze(-1)  # (batch, hidden)

        return self.classifier(pooled)
|
|
|