| import pathlib |
| import tempfile |
| import logging |
| import os |
| import copy |
|
|
| import torch |
| from torch import nn |
|
|
| from timm.models.layers import trunc_normal_ |
|
|
| from .ImageEncoder import build_image_encoder |
| from .LangEncoder import build_lang_encoder |
| from .LangEncoder import build_tokenizer |
|
|
| import mup.init |
| from mup import set_base_shapes |
|
|
| from safetensors.torch import load_file |
|
|
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class UniCLModel(nn.Module):
    """UniCL: a dual-encoder contrastive model (image encoder + language encoder).

    Both encoders are projected into a shared embedding space of width
    ``config['UNICL_MODEL']['DIM_PROJECTION']`` via learnable projection
    matrices; similarity logits are scaled by a learnable ``logit_scale``.
    """

    def __init__(self, config: dict):
        super().__init__()

        self.conf_lang_encoder = config['LANG_ENCODER']
        self.tokenizer = build_tokenizer(self.conf_lang_encoder)

        self.lang_encoder = build_lang_encoder(self.conf_lang_encoder, self.tokenizer, config['VERBOSE'])

        dim_projection = config['UNICL_MODEL']['DIM_PROJECTION']
        if hasattr(self.lang_encoder, 'dim_out'):
            dim_out = self.lang_encoder.dim_out
        else:
            # Encoder does not expose its width: probe it with a dummy
            # (1, 1) LongTensor batch and read the hidden-state size.
            with torch.no_grad():
                dim_out = self.lang_encoder(
                    torch.zeros(1, 1).type(torch.LongTensor)
                )['last_hidden_state'].size(2)

        # Projections are left uninitialized here; custom_init_weights()
        # must be called before training (see build_unicl_model).
        self.lang_projection = nn.Parameter(torch.empty(dim_out, dim_projection))

        self.conf_image_encoder = config['IMAGE_ENCODER']
        self.image_encoder = build_image_encoder(self.conf_image_encoder, config['VERBOSE'])

        self.image_projection = nn.Parameter(
            torch.empty(self.image_encoder.dim_out, dim_projection)
        )

        # Learnable temperature; exponentiated in forward().
        self.logit_scale = nn.Parameter(torch.ones([]))

        if torch.cuda.is_available():
            self.device = torch.device(type="cuda", index=0)
        else:
            self.device = torch.device(type="cpu")

    def custom_init_weights(self, use_original_init: bool = True):
        """Initialize the two projection matrices.

        Args:
            use_original_init: if True, use timm's ``trunc_normal_``
                (standard parameterization); otherwise use mup's
                muP-aware ``trunc_normal_``.
        """
        self.use_original_init = use_original_init
        logger.info('Custom init: {}'.format('original init' if self.use_original_init else 'muP init'))

        if self.use_original_init:
            # Original (timm) initialization.
            custom_trunc_normal_ = trunc_normal_
        else:
            # muP initialization: variance scaled by the muP infshapes.
            custom_trunc_normal_ = mup.init.trunc_normal_

        custom_trunc_normal_(self.lang_projection, std=.02)
        custom_trunc_normal_(self.image_projection, std=.02)

    def _convert_old_weights(self, model_dict):
        """Rename keys from the legacy CLIP-style checkpoint layout
        (``visual.*`` / ``text.*`` / ``*_projection``) to this module's
        attribute names. Unrecognized keys pass through unchanged.
        """
        model_dict_updated = {}
        for k, v in model_dict.items():
            if k.startswith('visual.'):
                model_dict_updated['image_encoder.' + k[7:]] = v
            elif k.startswith('text.'):
                model_dict_updated['lang_encoder.' + k[5:]] = v
            elif k == 'vision_projection':
                model_dict_updated['image_projection'] = v
            elif k == 'text_projection':
                model_dict_updated['lang_projection'] = v
            else:
                model_dict_updated[k] = v

        return model_dict_updated

    def from_pretrained(self, pretrained='', pretrained_layers=(), verbose=True):
        """Load weights from a safetensors checkpoint.

        Args:
            pretrained: path to the checkpoint file; skipped (with a
                warning) if it is not a regular file.
            pretrained_layers: top-level submodule names to load, or a
                sequence starting with ``'*'`` to load everything. An
                empty sequence loads nothing. (Default changed from a
                mutable ``[]`` to an immutable ``()``.)
            verbose: log each non-image-encoder key as it is loaded.
        """
        if not os.path.isfile(pretrained):
            logger.warning(f'=> Pretrained model ({pretrained}) is not a file, skip init weight')
            return

        pretrained_dict = load_file(pretrained)
        logger.info(f'=> Loading pretrained model {pretrained}')
        model_dict = self.state_dict()
        pretrained_dict = self._convert_old_weights(pretrained_dict)
        # Move everything to the model's device up front.
        pretrained_dict = {
            k: v.to(self.device) for k, v in pretrained_dict.items()
        }
        need_init_state_dict = {}
        image_encoder_state_dict = {}
        for k, v in pretrained_dict.items():
            # Bug fix: guard the [0] access — the original unconditionally
            # evaluated `pretrained_layers[0]`, raising IndexError whenever
            # pretrained_layers was empty (its old default).
            need_init = (
                k.split('.')[0] in pretrained_layers
                or (len(pretrained_layers) > 0 and pretrained_layers[0] == '*')
            )

            if need_init:
                if k.startswith('image_encoder.'):
                    # NOTE(review): keys keep their 'image_encoder.' prefix
                    # here — presumably from_state_dict handles that; confirm.
                    image_encoder_state_dict[k] = v.to(self.device)
                else:
                    if verbose:
                        logger.info(f'=> init {k} from {pretrained}')

                    if 'positional_embedding' in k and v.size() != model_dict[k].size():
                        # Resize a (length, width) positional embedding to the
                        # current length via 1-D linear interpolation; widths
                        # must match, otherwise the key is skipped for resize.
                        positional_embedding_pretrained = v
                        positional_embedding_current = model_dict[k]
                        L1, nH1 = positional_embedding_pretrained.size()
                        L2, nH2 = positional_embedding_current.size()
                        if nH1 != nH2:
                            logger.info(f"Error in loading {k}, passing")
                        else:
                            if L1 != L2:
                                logger.info(
                                    '=> load_pretrained: resized variant: {} to {}'
                                    .format((L1, nH1), (L2, nH2))
                                )

                                # interpolate expects (N, C, L): add a batch dim
                                # and move the width axis to the channel slot.
                                posemb = positional_embedding_pretrained.float()
                                posemb_grid = posemb.unsqueeze(dim=0).permute(0, 2, 1)
                                posemb_grid = torch.nn.functional.interpolate(posemb_grid, size=L2, mode='linear')
                                posemb_grid = posemb_grid.permute(0, 2, 1).squeeze(dim=0)
                                v = posemb_grid

                    need_init_state_dict[k] = v.to(self.device)
        self.image_encoder.from_state_dict(image_encoder_state_dict, ['*'], verbose)
        # strict=False: the checkpoint may cover only a subset of this model.
        self.load_state_dict(need_init_state_dict, strict=False)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return parameter names (relative to this module) to exclude from
        weight decay, including any reported by the two encoders."""
        no_weight_decay = {'logit_scale'}
        if hasattr(self.lang_encoder, 'no_weight_decay'):
            for k in self.lang_encoder.no_weight_decay():
                no_weight_decay.add('lang_encoder.' + k)

        # Bug fix: this previously iterated self.visual.no_weight_decay(),
        # but no `visual` attribute exists on this class — the hasattr guard
        # checks image_encoder, so the loop always raised AttributeError.
        if hasattr(self.image_encoder, 'no_weight_decay'):
            for k in self.image_encoder.no_weight_decay():
                no_weight_decay.add('image_encoder.' + k)

        return no_weight_decay

    @property
    def dtype(self):
        # The logit scale's dtype tracks the model's floating-point dtype.
        return self.logit_scale.dtype

    def encode_image(self, image, norm=True):
        """Embed a batch of images; L2-normalize rows unless norm=False."""
        x = self.image_encoder.forward_features(image)
        x = x @ self.image_projection

        if norm:
            x = x / x.norm(dim=-1, keepdim=True)

        return x

    def encode_text(self, text, norm=True):
        """Embed a batch of tokenized text; L2-normalize unless norm=False.

        Pooling depends on the tokenizer: for 'clip', take the hidden state
        at the position of the max token id in input_ids (presumably the
        end-of-text token — confirm against the tokenizer); otherwise take
        position 0.
        """
        x = self.lang_encoder(**text)
        x = x['last_hidden_state']

        if self.conf_lang_encoder['TOKENIZER'] == 'clip':
            x = x[torch.arange(x.size(0)), text['input_ids'].argmax(dim=-1)]
        else:
            x = x[:, 0]

        x = x @ self.lang_projection

        if norm:
            x = x / x.norm(dim=-1, keepdim=True)

        return x

    def forward(self, image, text):
        """Return (image_features, text_features, T) where T is the
        exponentiated learnable temperature for scaling similarities."""
        features_image = self.encode_image(image)
        features_text = self.encode_text(text)

        T = self.logit_scale.exp()

        return features_image, features_text, T
|
|
|
|
def create_model(config):
    """Factory: build a plain (standard-parameterization) UniCLModel."""
    return UniCLModel(config)
|
|
|
|
def create_mup_model(config):
    """Create a UniCL model under mu-parameterization (muP).

    Builds the target model, then a base model (width mult 1.0) and a delta
    model (width mult 2.0) from scaled standard-param copies of the config,
    and calls ``mup.set_base_shapes`` so the target model's parameters carry
    muP infshape metadata. The encoders are detached while base shapes are
    set and restored afterwards.
    """
    def gen_config(config, wm):
        # The input config must be a muP config (STANDPARAM False in every
        # section); the generated copies are standard-param with widths
        # scaled by the width multiplier `wm`.
        assert (not config['UNICL_MODEL']['STANDPARAM']) and \
            (not config['LANG_ENCODER']['STANDPARAM']) and \
            (not config['IMAGE_ENCODER']['SPEC']['STANDPARAM'])
        new_config = copy.deepcopy(config)
        logger.info(f'Generate config with width mult = {wm}:')

        def scale_section(section, prefix, names):
            # Shared scaling loop (previously duplicated three times):
            # rewrite each entry in `names` from its BASE_* value times wm.
            # Entries may be scalars or per-stage lists.
            section['STANDPARAM'] = True
            for name in names:
                base_value = section['BASE_' + name]
                if isinstance(base_value, (list, tuple)):
                    new_values = [round(v * wm) for v in base_value]
                else:
                    new_values = round(base_value * wm)
                logger.info(f'{prefix}["{name}"]: {section[name]} -> {new_values}')
                section[name] = new_values

        scale_section(new_config['UNICL_MODEL'], 'config["UNICL_MODEL"]',
                      ['DIM_PROJECTION'])
        scale_section(new_config['LANG_ENCODER'], 'config["LANG_ENCODER"]',
                      ['WIDTH', 'HEADS'])
        scale_section(new_config['IMAGE_ENCODER']['SPEC'], 'config["IMAGE_ENCODER"]["SPEC"]',
                      ['DIM_EMBED', 'NUM_HEADS', 'NUM_GROUPS'])

        return new_config

    logger.info('muP: Create models and set base shapes')
    logger.info('=> Create model')
    model = create_model(config)

    # Detach the heavyweight encoders so set_base_shapes only walks the
    # projection/scale parameters owned directly by UniCLModel.
    lang_encoder, image_encoder = model.lang_encoder, model.image_encoder
    model.lang_encoder, model.image_encoder = None, None

    logger.info('=> Create base model')
    base_config = gen_config(config, wm=1.0)
    base_model = create_model(base_config)
    del base_model.lang_encoder, base_model.image_encoder

    logger.info('=> Create delta model')
    delta_config = gen_config(config, wm=2.0)
    delta_model = create_model(delta_config)
    del delta_model.lang_encoder, delta_model.image_encoder

    logger.info('=> Set base shapes in model for training')
    set_base_shapes(model, base=base_model, delta=delta_model)

    # Reattach the real encoders (their muP setup is handled elsewhere —
    # presumably inside the encoder builders; confirm).
    model.lang_encoder, model.image_encoder = lang_encoder, image_encoder

    return model
|
|
|
|
def build_unicl_model(config, **kwargs):
    """Build a UniCL model per *config*: choose standard vs muP
    parameterization, initialize the projection weights, and optionally
    load pretrained weights from a local path or a URL."""
    unicl_conf = config['UNICL_MODEL']
    standparam = unicl_conf.get('STANDPARAM', True)

    if standparam:
        logger.info('Create model with standard parameterization')
        model, use_original_init = create_model(config), True
    else:
        logger.info('Create model with mu parameterization')
        model, use_original_init = create_mup_model(config), False

    # Projection matrices are allocated empty; initialize them now.
    model.custom_init_weights(use_original_init=use_original_init)

    if unicl_conf['LOAD_PRETRAINED']:
        pretrained_path = unicl_conf['PRETRAINED']
        from .Distributed.Utils import is_valid_url, download_file
        layers = unicl_conf['PRETRAINED_LAYERS']
        verbose = config['VERBOSE']
        if is_valid_url(pretrained_path):
            # Remote checkpoint: download to a temporary file first.
            with tempfile.TemporaryDirectory() as tmp_path:
                file_local_path = pathlib.Path(tmp_path) / 'base_model.pt'
                download_file(pretrained_path, file_local_path)
                model.from_pretrained(str(file_local_path), layers, verbose)
        else:
            model.from_pretrained(pretrained_path, layers, verbose)

    return model
|
|