from transformers import PretrainedConfig, PreTrainedModel, BertModel, BertConfig
import timm
import torch.nn as nn
import torch
import numpy
from torchvision import transforms
from PIL import Image


class RenameLayerScale(nn.Module):
    """Drop-in replacement for timm's LayerScale with its parameter renamed.

    timm stores the LayerScale parameter as `gamma`, but
    `PreTrainedModel.from_pretrained` rewrites state-dict keys containing
    `gamma`/`beta` to `weight`/`bias` (a legacy of TensorFlow checkpoint
    conversion). Naming the parameter `weight` up front keeps the module's
    keys identical before and after a save/load round trip.
    """

    def __init__(
        self,
        dim: int,
        init_values: float = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.weight = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale each channel by its learned weight; optionally in place.
        return x.mul_(self.weight) if self.inplace else x * self.weight


# Monkey-patch timm so every LayerScale built below uses the renamed class.
timm.models.vision_transformer.LayerScale = RenameLayerScale
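
# Quick sanity check of the patch's effect (a sketch; the "ls1"/"ls2"
# attribute names are an assumption based on timm's ViT block layout):
# with the patch applied, LayerScale parameters appear under "weight" keys
# instead of timm's default "gamma", e.g.
#
#   vit = timm.create_model("vit_large_patch16_224", init_values=1e-5)
#   assert any(k.endswith("ls1.weight") for k in vit.state_dict())
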

class KEEPConfig(PretrainedConfig):
    model_type = "keep"

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        projection_dim=768,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # `vision_config` must supply the keys consumed by KEEPModel below
        # (img_size, patch_size, init_values, num_classes); `text_config`
        # holds keyword arguments for a BertConfig.
        self.vision_config = vision_config
        self.text_config = text_config
        self.projection_dim = projection_dim

class KEEPModel(PreTrainedModel):
    config_class = KEEPConfig

    def __init__(self, config):
        super().__init__(config)

        # Vision tower: a ViT-L/16 built by timm and configured entirely from
        # vision_config; weights are supplied later via from_pretrained.
        vision_config = config.vision_config
        self.visual = timm.create_model(
            "vit_large_patch16_224",
            pretrained=False,
            img_size=vision_config["img_size"],
            patch_size=vision_config["patch_size"],
            init_values=vision_config["init_values"],
            num_classes=vision_config["num_classes"],
        )

        # Projection head mapping ViT features into the shared embedding space.
        self.visual_head = nn.Sequential(
            nn.Linear(self.visual.num_features, config.projection_dim),
            nn.GELU(),
            nn.Linear(config.projection_dim, config.projection_dim),
        )

        # Text tower: a BERT encoder built from the stored config.
        text_config = BertConfig(**config.text_config)
        self.text = BertModel(text_config)

        # CLIP-style learnable temperature, initialized to ln(1 / 0.04).
        self.logit_scale = nn.Parameter(torch.ones([]) * numpy.log(1 / 0.04))

    def encode_image(self, image_inputs):
        # ViT features -> projection head -> L2-normalized image embedding.
        vision_features = self.visual(image_inputs)
        vision_features = torch.nn.functional.normalize(
            self.visual_head(vision_features), dim=-1
        )
        return vision_features

    def encode_text(self, text_inputs):
        # BERT pooler output -> L2-normalized text embedding.
        text_features = torch.nn.functional.normalize(
            self.text(**text_inputs).pooler_output, dim=-1
        )
        return text_features

    def forward(self, image_inputs, text_inputs):
        vision_features = self.encode_image(image_inputs)
        text_features = self.encode_text(text_inputs)
        return {
            "vision_features": vision_features,
            "text_features": text_features,
        }
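

if __name__ == "__main__":
    # Minimal end-to-end sketch of how the model might be exercised. The
    # config values, normalization statistics, and tokenizer choice below are
    # illustrative assumptions, not values shipped with any checkpoint; a real
    # checkpoint would instead be loaded with KEEPModel.from_pretrained(...).
    from transformers import BertTokenizer

    config = KEEPConfig(
        vision_config={
            "img_size": 224,
            "patch_size": 16,
            "init_values": 1e-5,
            "num_classes": 0,  # 0 -> timm returns pooled features, no head
        },
        text_config={},  # BertConfig defaults (BERT-base); assumed
        projection_dim=768,
    )
    model = KEEPModel(config).eval()

    # ImageNet-style preprocessing (assumed mean/std).
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225)),
    ])
    image_inputs = preprocess(Image.new("RGB", (224, 224))).unsqueeze(0)

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    text_inputs = tokenizer(["an example caption"], return_tensors="pt",
                            padding=True)

    with torch.no_grad():
        outputs = model(image_inputs, text_inputs)
    # Both embeddings land in the shared 768-dim space.
    print(outputs["vision_features"].shape)  # torch.Size([1, 768])
    print(outputs["text_features"].shape)    # torch.Size([1, 768])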