| from torchvision.models import resnet50, ResNet50_Weights |
| from transformers import PreTrainedModel |
| from .config import ResnetConfig |
| import torch.nn as nn |
|
|
class ResNet50(nn.Module):
    """ResNet-50 feature extractor followed by a 2048 -> 768 projection.

    The network is created with torchvision's ``resnet50`` using the
    ``IMAGENET1K_V1`` weights. Its last two child modules (torchvision's
    own pooling layer and classification head) are left out of
    ``backbone``; the resulting feature map is average-pooled to a single
    value per channel, flattened, and projected by ``fc_1``.
    """

    def __init__(self):
        super().__init__()
        # ``cnn`` is kept as an attribute: ResNet50AffectiveFeatureExtractor
        # deletes it by name after constructing this module.
        self.cnn = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        # Everything except the final two children of the torchvision model.
        self.backbone = nn.Sequential(*list(self.cnn.children())[:-2])
        # Global average pooling: an output size of 1 collapses the feature
        # map whatever its spatial size is. A fixed 7x7 window only yields
        # the 2048 features ``fc_1`` expects when the backbone emits exactly
        # a 7x7 map (224x224 inputs); other resolutions would fail.
        # The layer has no parameters and keeps its position in the
        # Sequential, so state-dict keys are unaffected.
        self.flaten = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())
        self.fc_1 = nn.Linear(2048, 768)

    def forward(self, x):
        """Compute 768-dimensional features for an image or a batch.

        Args:
            x: Input tensor with 3 or 4 dimensions. A 3-D tensor is treated
                as a single sample and gets a leading batch dimension.
                Presumably laid out as ``(C, H, W)`` / ``(N, C, H, W)`` as
                torchvision's ResNet expects -- confirm against callers.

        Returns:
            The output of ``fc_1`` with ``squeeze(0)`` applied: shape
            ``(N, 768)`` for ``N > 1`` and ``(768,)`` when the batch
            dimension is 1.
        """
        if x.dim() == 3:
            x = x.unsqueeze(0)
        features = self.fc_1(self.flaten(self.backbone(x)))
        # NOTE(review): the squeeze is unconditional, so a genuine batch of
        # size 1 also loses its batch dimension. Kept as is because callers
        # may rely on the existing output shape.
        return features.squeeze(0)
| |
class ResNet50AffectiveFeatureExtractor(PreTrainedModel):
    """``PreTrainedModel`` wrapper that delegates to a ``ResNet50`` module.

    ``ResnetConfig`` is registered as the configuration class.
    """

    config_class = ResnetConfig

    def __init__(self, config):
        """Create the wrapped ``ResNet50`` network.

        Args:
            config: Configuration object passed on to
                ``PreTrainedModel.__init__``.
        """
        super().__init__(config)
        network = ResNet50()
        # ``backbone`` was assembled from the children of ``cnn``, so those
        # layers stay reachable through it; removing the ``cnn`` attribute
        # drops the second reference to them along with the children of
        # ``cnn`` that ``backbone`` does not include.
        delattr(network, "cnn")
        self.model = network

    def forward(self, tensor):
        """Return the wrapped network's output for ``tensor``."""
        features = self.model(tensor)
        return features