import torch
import torch.nn as nn

from sat.model import ViTModel, BaseModel
from sat.model import BaseMixin
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode


class LNFinalyMixin(BaseMixin):
    """Replaces the ViT classification head: the final hook simply layer-normalizes the output features."""

    def __init__(self, hidden_size):
        super().__init__()
        self.ln_vision = nn.LayerNorm(hidden_size)

    def final_forward(self, logits, **kw_args):
        return self.ln_vision(logits)


class EVAViT(ViTModel):
    """EVA ViT image encoder with the classification head swapped for a final LayerNorm."""

    def __init__(self, args, transformer=None, parallel_output=True, **kwargs):
        super().__init__(args, transformer=transformer, parallel_output=parallel_output, **kwargs)
        self.del_mixin("cls")
        self.add_mixin("cls", LNFinalyMixin(args.hidden_size))

    def forward(self, image):
        batch_size = image.size(0)
        # input_ids is only a dummy token sequence; the real inputs are the image patches passed
        # via `image`, and the single-element all-ones mask means "attend to everything".
        input_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=image.device)
        attention_mask = torch.tensor([[1.]], dtype=image.dtype, device=image.device)
        return super().forward(input_ids=input_ids, position_ids=None, attention_mask=attention_mask, image=image)


class QFormer(BaseModel):
    """Q-Former: a small transformer whose 32 learned query tokens cross-attend to the image features."""

    def __init__(self, args, transformer=None, parallel_output=True, **kwargs):
        super().__init__(args, transformer=transformer, parallel_output=parallel_output,
                         activation_func=nn.functional.gelu, **kwargs)
        # The queries carry no positional information, so position embeddings are removed.
        self.transformer.position_embeddings = None

    def final_forward(self, logits, **kw_args):
        return logits

    def position_embedding_forward(self, position_ids, **kw_args):
        return None

    def forward(self, encoder_outputs):
        batch_size = encoder_outputs.size(0)
        # Token ids 0..31 index the word-embedding table, which acts as the set of 32 learned query vectors.
        input_ids = torch.arange(32, dtype=torch.long, device=encoder_outputs.device).unsqueeze(0).expand(batch_size, -1)
        attention_mask = torch.tensor([[1.]], dtype=encoder_outputs.dtype, device=encoder_outputs.device)
        cross_attention_mask = torch.tensor([[1.]], dtype=encoder_outputs.dtype, device=encoder_outputs.device)
        return super().forward(input_ids=input_ids, position_ids=None, attention_mask=attention_mask,
                               encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask)


class BLIP2(torch.nn.Module):
    """EVA ViT encoder + Q-Former + linear projection into the language model's hidden size."""

    def __init__(self, eva_args, qformer_args, vit=None, qformer=None, **kwargs):
        super().__init__()
        if vit is not None:
            self.vit = vit
        else:
            self.vit = EVAViT(EVAViT.get_args(**eva_args))
        if qformer is not None:
            self.qformer = qformer
        else:
            self.qformer = QFormer(QFormer.get_args(**qformer_args))

        # Project the 768-dim Q-Former outputs to the 4096-dim hidden size expected by the
        # language model, placed on the same device/dtype as the Q-Former weights.
        qformer_param = next(self.qformer.parameters())
        self.glm_proj = nn.Linear(768, 4096).to(qformer_param.device).to(qformer_param.dtype)

    def forward(self, image, **kwargs):
        enc = self.vit(image)[0]
        out = self.qformer(enc)[0]
        return self.glm_proj(out)
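
# Shape sketch of BLIP2.forward (illustrative; the exact ViT feature width depends on eva_args):
#   image                -> (batch, 3, H, W)
#   self.vit(image)[0]   -> (batch, num_image_tokens, vit_hidden)
#   self.qformer(enc)[0] -> (batch, 32, 768), the 32 query vectors (glm_proj assumes width 768)
#   self.glm_proj(out)   -> (batch, 32, 4096), fed to the language model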


class BlipImageBaseProcessor():
    def __init__(self, mean=None, std=None):
        # Defaults are the OpenAI CLIP normalization statistics.
        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)

        self.normalize = transforms.Normalize(mean, std)


class BlipImageEvalProcessor(BlipImageBaseProcessor):
    def __init__(self, image_size=384, mean=None, std=None):
        super().__init__(mean=mean, std=std)

        # Deterministic eval-time preprocessing: bicubic resize to a square, then tensor + normalize.
        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (image_size, image_size), interpolation=InterpolationMode.BICUBIC
                ),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)
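

if __name__ == "__main__":
    # Minimal smoke-test sketch: it avoids loading real EVA / Q-Former weights by passing tiny
    # stand-in modules, so it only exercises BlipImageEvalProcessor and the 768 -> 4096 projection
    # wiring of BLIP2. The stub feature sizes (257, 1408) are illustrative placeholders, not real configs.
    from PIL import Image

    # Preprocess a synthetic image the same way a real one would be.
    processor = BlipImageEvalProcessor(image_size=224)
    dummy_image = Image.new("RGB", (640, 480), color=(127, 127, 127))
    pixel_values = processor(dummy_image).unsqueeze(0)  # (1, 3, 224, 224)

    class _StubViT(nn.Module):
        def forward(self, image):
            # Mimics EVAViT's (features, ...) return convention with made-up feature sizes.
            return (torch.randn(image.size(0), 257, 1408),)

    class _StubQFormer(nn.Module):
        def __init__(self):
            super().__init__()
            self.dummy = nn.Parameter(torch.zeros(1))  # BLIP2 reads device/dtype from a parameter

        def forward(self, encoder_outputs):
            # 32 query vectors of width 768, matching what glm_proj expects.
            return (torch.randn(encoder_outputs.size(0), 32, 768),)

    model = BLIP2({}, {}, vit=_StubViT(), qformer=_StubQFormer())
    features = model(pixel_values)
    print(pixel_values.shape, features.shape)  # (1, 3, 224, 224) and (1, 32, 4096)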