| ''' |
| * Adapted from BLIP (https://github.com/salesforce/BLIP) |
| ''' |
|
|
import warnings
# NOTE(review): a bare filterwarnings("ignore") installs a catch-all "ignore"
# filter, so merely importing this module silences every warning in the whole
# process, not only warnings raised by this file. If only specific noisy calls
# need quieting, consider scoping it with warnings.catch_warnings() instead.
warnings.filterwarnings("ignore")
|
|
| import torch |
| import os |
| from urllib.parse import urlparse |
| from timm.models.hub import download_cached_file |
| from transformers import BertTokenizer |
| from .vit import VisionTransformer, interpolate_pos_embed |
|
|
|
|
def default_bert():
    """Return the absolute path of the bundled ``bert-base-uncased`` directory.

    The location is derived from this source file: four directory levels up
    gives the project root, under which the weights are expected at
    ``models/QualityMetric/bert-base-uncased``. The path is not checked for
    existence.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    root = os.path.abspath(os.path.join(here, '../../../../'))
    return os.path.join(root, 'models', 'QualityMetric', "bert-base-uncased")
|
|
|
|
def init_tokenizer(bert_model_path):
    """Load a BERT tokenizer and register BLIP's extra special tokens.

    Args:
        bert_model_path: Anything accepted by ``BertTokenizer.from_pretrained``
            (a local directory or a model identifier).

    Returns:
        The tokenizer, with ``'[DEC]'`` registered as its BOS token, ``'[ENC]'``
        registered as an additional special token, and a new attribute
        ``enc_token_id`` holding the first entry of
        ``additional_special_tokens_ids`` (expected to be the id of ``'[ENC]'``).
    """
    tok = BertTokenizer.from_pretrained(bert_model_path)

    # Two separate registrations, in this order, so the new vocabulary ids are
    # assigned exactly as before: '[DEC]' first, then '[ENC]'.
    tok.add_special_tokens({'bos_token': '[DEC]'})
    tok.add_special_tokens({'additional_special_tokens': ['[ENC]']})

    extra_ids = tok.additional_special_tokens_ids
    tok.enc_token_id = extra_ids[0]
    return tok
|
|
|
|
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
    """Build the vision encoder used by the BLIP-derived models.

    Args:
        vit: Backbone size, ``'base'`` (width 768, depth 12, 12 heads) or
            ``'large'`` (width 1024, depth 24, 16 heads). Both use 16-pixel
            patches.
        image_size: Input resolution, forwarded as ``img_size``.
        use_grad_checkpointing: Forwarded to ``VisionTransformer`` unchanged.
        ckpt_layer: Forwarded to ``VisionTransformer`` unchanged.
        drop_path_rate: Forwarded as ``drop_path_rate``. A falsy value (the
            default ``0``) selects the per-size default: ``0`` for ``'base'``
            and ``0.1`` for ``'large'``.

    Returns:
        ``(visual_encoder, vision_width)``, where ``vision_width`` is the
        encoder's embedding dimension.

    Raises:
        AssertionError: if ``vit`` is neither ``'base'`` nor ``'large'``.
    """
    assert vit in ['base', 'large'], "vit parameter must be base or large"

    # vit -> (embed width, depth, attention heads, default drop_path_rate)
    configs = {
        'base': (768, 12, 12, 0),
        'large': (1024, 24, 16, 0.1),
    }
    vision_width, depth, num_heads, default_drop_path = configs[vit]

    visual_encoder = VisionTransformer(
        img_size=image_size,
        patch_size=16,
        embed_dim=vision_width,
        depth=depth,
        num_heads=num_heads,
        use_grad_checkpointing=use_grad_checkpointing,
        ckpt_layer=ckpt_layer,
        # The previous ``0.1 or drop_path_rate`` always evaluated to 0.1 for
        # 'large', silently discarding the caller's value. Honour an explicit
        # non-zero rate and fall back to the per-size default otherwise.
        drop_path_rate=drop_path_rate or default_drop_path,
    )
    return visual_encoder, vision_width
|
|
|
|
def is_url(url_or_filename):
    """Tell whether *url_or_filename* should be downloaded rather than opened.

    Only the ``http`` and ``https`` schemes count as URLs. Anything else, such
    as plain filesystem paths or ``file://`` and ``ftp://`` URLs, is reported
    as not-a-URL.
    """
    scheme = urlparse(url_or_filename).scheme
    return scheme == "http" or scheme == "https"
|
|
def load_checkpoint(model, url_or_filename):
    """Load BLIP checkpoint weights into *model*, tolerating mismatches.

    The checkpoint is fetched via ``download_cached_file`` when
    *url_or_filename* is an http(s) URL, or read directly when it names an
    existing file. Its ``'model'`` entry is used as the state dict. Before
    loading:

    * ``visual_encoder.pos_embed`` is resized with ``interpolate_pos_embed``
      to fit ``model.visual_encoder``. The same is done for
      ``visual_encoder_m.pos_embed`` when both the model and the checkpoint
      have it.
    * Any checkpoint tensor whose shape differs from the model's is printed
      and dropped.

    Loading uses ``strict=False``, so missing and unexpected keys are
    reported in the returned message rather than raised.

    Args:
        model: Module exposing ``visual_encoder`` (and optionally
            ``visual_encoder_m``), ``state_dict()`` and ``load_state_dict()``.
        url_or_filename: http(s) URL or local path of the checkpoint.

    Returns:
        ``(model, msg)``, where ``msg`` is the value returned by
        ``model.load_state_dict(..., strict=False)``.

    Raises:
        RuntimeError: if *url_or_filename* is neither a URL nor an existing file.
    """
    if is_url(url_or_filename):
        checkpoint_path = download_cached_file(url_or_filename, check_hash=False, progress=True)
    elif os.path.isfile(url_or_filename):
        checkpoint_path = url_or_filename
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    # SECURITY NOTE(review): torch.load unpickles the file and can execute
    # arbitrary code; only load checkpoints from trusted sources (this includes
    # downloaded URLs).
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict = checkpoint['model']

    # model.state_dict() builds a new mapping on each call; take it once
    # instead of once per key inside the loop below.
    model_state = model.state_dict()

    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(
        state_dict['visual_encoder.pos_embed'], model.visual_encoder)

    # Only interpolate the momentum encoder's embedding when the checkpoint
    # actually carries it; otherwise leave it to strict=False to report as
    # missing instead of failing with KeyError here.
    momentum_key = 'visual_encoder_m.pos_embed'
    if momentum_key in model_state and momentum_key in state_dict:
        state_dict[momentum_key] = interpolate_pos_embed(
            state_dict[momentum_key], model.visual_encoder_m)

    # Drop tensors whose shapes disagree with the model so load_state_dict
    # does not raise on them.
    for key, param in model_state.items():
        if key in state_dict and state_dict[key].shape != param.shape:
            print(key, ": ", state_dict[key].shape, ', ', param.shape)
            del state_dict[key]

    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s' % url_or_filename)
    return model, msg
| |
|
|