ProFound / models /classifier.py
Anonymise's picture
add necessary module
45461c9
raw
history blame contribute delete
610 Bytes
import torch
import torch.nn as nn
class Classifier(nn.Module):
    """Classifier made of an encoder followed by a small MLP head.

    The head is ``Linear(embed_dim, bottleneck_dim) -> BatchNorm1d ->
    ReLU -> Linear(bottleneck_dim, num_classes)``; its final layer has no
    activation, so ``forward`` returns unnormalized class scores.

    Args:
        encoder: Module producing features. It must expose an ``embed_dim``
            attribute, which sets the input width of the head. Its output
            may be a tensor or any tuple (including a namedtuple); for a
            tuple, only the first element is passed to the head.
        num_classes: Number of output scores produced by the head.
        bottleneck_dim: Width of the hidden layer in the head.
    """

    def __init__(self, encoder, num_classes, bottleneck_dim=256):
        super().__init__()
        self.encoder = encoder
        # The head's input width is read from the encoder, so the encoder
        # must declare ``embed_dim``.
        self.embed_dim = self.encoder.embed_dim
        self.head = nn.Sequential(
            nn.Linear(self.embed_dim, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Linear(bottleneck_dim, num_classes),
        )

    def forward(self, x):
        """Encode ``x`` and apply the classification head.

        Args:
            x: Input accepted by ``self.encoder``.

        Returns:
            Output of the head's last ``Linear`` layer (one score per class
            along the last dimension).
        """
        x = self.encoder(x)
        # Some encoders return (features, extras...). Use isinstance so tuple
        # subclasses (namedtuples, torch.return_types) are unwrapped too;
        # `type(x) == tuple` silently missed them.
        if isinstance(x, tuple):
            x = x[0]
        # NOTE(review): features are presumably (batch, embed_dim) -- TODO
        # confirm; BatchNorm1d in training mode needs batch size > 1.
        return self.head(x)