import os
import re
import math
import torch
import torch.nn as nn
from .clip_encoder import CLIPVisionTower
from .eva_clip_encoder import EvaClipVisionTower
from .siglip_encoder import SiglipVisionTower
from .google_siglip_encoder import GoogleSiglipVisionTower
from llava.model.utils import LayerNorm
from .qformer import BertConfig, BertLMHeadModel
from .resampler import Resampler, TokenCompressor
from torch.nn.init import trunc_normal_


def build_vision_tower(vision_tower_cfg, **kwargs):
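    """Instantiate the vision tower named by `mm_vision_tower` (falling back to `vision_tower`) in the config."""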
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))

    if vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
        vision_tower = CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    elif vision_tower.startswith("eva"):
        vision_tower = EvaClipVisionTower(vision_tower, args=vision_tower_cfg)
    elif vision_tower.startswith("google/siglip"):
        vision_tower = GoogleSiglipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    elif 'HuggingFaceM4/siglip' in vision_tower:
        vision_tower = SiglipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    else:
        raise ValueError(f'Unknown vision tower: {vision_tower}')

    return vision_tower


def build_Qformer(num_query_token, vision_width, extra_num_query_token=64, cross_attention_freq=2):
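    """Build a BERT-based Q-Former with `num_query_token` learnable query tokens and a LayerNorm for the vision features.

    The text-only parts of the BERT LM head model (classification head, word/position embeddings,
    and the per-layer text feed-forward branches) are stripped, since only the query pathway is used.
    """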
    ln_vision = LayerNorm(vision_width)
    encoder_config = BertConfig.from_pretrained("./model/bert-base-uncased")
    encoder_config.encoder_width = vision_width

    # Insert cross-attention to the visual features every `cross_attention_freq` BERT layers.
    encoder_config.add_cross_attention = True
    encoder_config.cross_attention_freq = cross_attention_freq
    encoder_config.query_length = num_query_token
    Qformer = BertLMHeadModel(config=encoder_config)
    query_tokens = nn.Parameter(
        torch.zeros(1, num_query_token, encoder_config.hidden_size)
    )
    query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)

    # Strip the text-only components: LM head, word/position embeddings, and the
    # per-layer feed-forward branches used for text tokens.
    Qformer.cls = None
    Qformer.bert.embeddings.word_embeddings = None
    Qformer.bert.embeddings.position_embeddings = None
    for layer in Qformer.bert.encoder.layer:
        layer.output = None
        layer.intermediate = None

    return Qformer, ln_vision, query_tokens


def build_adapter_module(cfg, vision_width):
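    """Thin factory wrapper that builds an AdapterModule for the given config and vision feature width."""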
    return AdapterModule(cfg, vision_width)


class IdentityMap(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x


class AdapterModule(nn.Module):
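    """Adapter applied to the vision tower's output features.

    Depending on `config.adapter_module_name`, the visual tokens go through a flash-attention
    Perceiver, a naive resampler, a BERT-based Q-Former, or an identity map; a learned per-frame
    position embedding (up to `config.max_num_segments` frames) is then added in `forward`.
    Adapter names containing 'compress_token' (ending in the target token count) additionally
    attach a TokenCompressor.
    """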
    def __init__(self, config, vision_width):
        super().__init__()
        self.adapter_name = config.adapter_module_name
        self.config = config
        self.output_dim = vision_width
        if 'perceiver' in self.adapter_name:
            from flash_perceiver import Perceiver
            self.adapter = Perceiver(
                input_dim=vision_width,
                depth=6,
                output_dim=vision_width,
                num_latents=self.config.num_query_token,
                latent_dim=1024,
                cross_heads=1,
                cross_head_dim=128,
                cross_rotary_emb_dim=0,
                cross_attn_dropout=0.0,
                latent_heads=8,
                latent_head_dim=128,
                latent_rotary_emb_dim=0,
                latent_attn_dropout=0.0,
                weight_tie_layers=False,
                gated_mlp=True,
                self_per_cross_attn=1,
                num_zero_tokens=None,
                use_flash_attn=True,
            )
        elif 'naive_resampler' in self.adapter_name:
            assert math.isqrt(self.config.num_query_token) ** 2 == self.config.num_query_token, \
                'num_query_token must be a perfect square'
            self.adapter = Resampler(
                grid_size=math.isqrt(self.config.num_query_token),
                embed_dim=vision_width,
                num_heads=8,
            )
        elif 'qformer' in self.adapter_name:
            Qformer, ln_vision, query_tokens = build_Qformer(
                self.config.num_query_token, vision_width)
            self.adapter = Qformer
            self.ln_vision = ln_vision
            self.query_tokens = query_tokens
            self.output_dim = Qformer.config.hidden_size
        elif 'none' in self.adapter_name:
            self.adapter = IdentityMap()
        else:
            raise ValueError(f'Unknown adapter module: {self.adapter_name}')

        self.is_loaded = False

        if 'compress_token' in self.adapter_name:
            match = re.search(r'\d+$', self.adapter_name)
            assert match is not None, 'compress_token adapter names must end with the number of compressed tokens'
            self.token_compressor = TokenCompressor(
                num_compressed_token=int(match.group()),
                embed_dim=self.config.hidden_size,
                num_heads=8,
            )
            self.compress_version = 'v1' if 'v1' in self.adapter_name else 'v0'

        self.frame_position_encoding = nn.Embedding(
            config.max_num_segments,
            self.output_dim,
        )

        self.adapter.apply(self._init_weights)

    def _init_weights(self, m):
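        """Initialize Linear/Embedding weights with a truncated normal (std 0.02); reset LayerNorm weight/bias to 1/0."""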
        if isinstance(m, (nn.Linear, nn.Embedding)):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, image_features, frame_ids):
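        """Apply the configured adapter to the image features and add a per-frame position embedding.

        `frame_ids` selects rows of `frame_position_encoding`; the embedding is broadcast across
        the token dimension of the adapted features.
        """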
        if 'perceiver' in self.adapter_name:
            adapted_image_features = self.adapter(image_features, return_embeddings=True)
        elif 'naive_resampler' in self.adapter_name:
            adapted_image_features = self.adapter(image_features)
        elif 'qformer' in self.adapter_name:
            image_features = self.ln_vision(image_features)
            query_tokens = self.query_tokens.expand(image_features.shape[0], -1, -1)
            attn_mask = torch.ones(image_features.size()[:-1], dtype=torch.long, device=image_features.device)
            adapted_image_features = self.adapter.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_features,
                encoder_attention_mask=attn_mask,
                return_dict=True,
            ).last_hidden_state
        elif 'none' in self.adapter_name:
            adapted_image_features = self.adapter(image_features)
        else:
            raise ValueError(f'Unknown adapter module: {self.adapter_name}')

        frame_embeddings = self.frame_position_encoding(frame_ids).unsqueeze(-2)
        adapted_image_features = adapted_image_features + frame_embeddings
        return adapted_image_features

    def compress_token_per_img(self, batch_image_features):
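        """Compress each image's visual tokens with the TokenCompressor when 'compress_token' is configured.

        Entries that already have no more tokens than `num_compressed_token` are passed through unchanged.
        """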
        if 'compress_token' not in self.adapter_name:
            return batch_image_features
        compressed_features = []
        for image_features in batch_image_features:
            if image_features.shape[1] < self.token_compressor.num_compressed_token:
                compressed_features.append(image_features)
            else:
                compressed_features.append(
                    self.token_compressor(image_features, compress_version=self.compress_version))
        return compressed_features

    def load_model(self):
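        """Load pretrained adapter weights from `config.adapter_module_path`, if given.

        Handles Q-Former checkpoints as well as projector-style checkpoints, and adapts the frame
        position embedding when the checkpoint was trained with a different number of segments.
        """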
        if self.is_loaded:
            return

        if getattr(self.config, 'adapter_module_path', None):
            checkpoint = torch.load(self.config.adapter_module_path, map_location="cpu")

            def get_w(weights, keyword):
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword + '.' in k}

            def get_variable_frame_encoding_w(model_weights, load_weights):
                # Copy as many frame position embeddings as both the checkpoint and the current
                # model can hold, keeping the model's own initialization for the remainder.
                keyword = 'frame_position_encoding'
                model_len = model_weights.shape[0]
                load_weights_f_encoding = get_w(load_weights, keyword)

                load_len = load_weights_f_encoding['weight'].shape[0]
                if model_len <= load_len:
                    value = load_weights_f_encoding['weight'][:model_len]
                else:
                    value = model_weights.clone().cpu()
                    value[:load_len] = load_weights_f_encoding['weight']
                return value

            if 'qformer' in self.adapter_name and ('projector.bin' not in self.config.adapter_module_path):
                state_dict = checkpoint["model"]
                self.adapter.load_state_dict(get_w(state_dict, 'Qformer'))
                self.ln_vision.load_state_dict(get_w(state_dict, 'ln_vision'))
                self.load_state_dict({'query_tokens': state_dict['query_tokens']}, strict=False)
                if getattr(self.config, 'pretrain_mm_mlp_adapter', None):
                    mm_projector_weights = torch.load(self.config.pretrain_mm_mlp_adapter, map_location='cpu')
                    frame_encoding_weight = get_variable_frame_encoding_w(self.frame_position_encoding.weight, mm_projector_weights)
                    self.frame_position_encoding.load_state_dict({'weight': frame_encoding_weight})
            else:
                frame_encoding_weight = get_variable_frame_encoding_w(self.frame_position_encoding.weight, checkpoint)
                for k in checkpoint.keys():
                    if 'frame_position_encoding' in k:
                        checkpoint[k] = frame_encoding_weight

                self.load_state_dict(get_w(checkpoint, 'adapter_module'))

        # Mark as loaded so repeated calls are no-ops.
        self.is_loaded = True

    def freeze_adapter_module(self, freeze_flag):
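        """Freeze (or unfreeze) every adapter parameter; the naive resampler's `pos_embed` is always kept frozen."""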
        for p in self.parameters():
            p.requires_grad = not freeze_flag

        if 'naive_resampler' in self.adapter_name:
            for name, p in self.named_parameters():
                if 'pos_embed' in name:
                    p.requires_grad = False