| |
| import math |
| import os.path as osp |
| import warnings |
| from collections import OrderedDict |
| from typing import List, Optional |
| import torch.nn.functional as F |
| import torch |
| import torch.nn as nn |
| from accelerate import init_empty_weights |
| from mmengine import print_log |
| from mmengine.config import Config, ConfigDict |
| from mmengine.model import BaseModel |
| from peft import get_peft_model, prepare_model_for_kbit_training |
| from transformers import (AddedToken, AutoConfig, CLIPImageProcessor, PreTrainedModel, |
| CLIPVisionModel, LlamaForCausalLM, |
| LlamaTokenizerFast, LlavaConfig, |
| LlavaForConditionalGeneration, LlavaProcessor) |
| from transformers.integrations import is_deepspeed_zero3_enabled |
|
|
| from xtuner.registry import BUILDER |
| from xtuner.model.modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2 |
| from xtuner.model.utils import (LoadWoInit, find_all_linear_names, |
| get_peft_model_state_dict, guess_load_checkpoint, |
| make_inputs_require_grad, traverse_dict) |
| from xtuner.utils import IGNORE_INDEX, IMAGE_TOKEN_INDEX |
| from xtuner.tools.utils import get_stop_criteria, is_cn_string |
| from transformers import GenerationConfig |
| from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, |
| PROMPT_TEMPLATE) |
|
|
|
|
def convert_state_dict_to_hf(state_dict, mapping):
    """Return a copy of ``state_dict`` with its keys renamed via ``mapping``.

    Keys ending in ``'.inv_freq'`` are dropped. For every other key, each
    ``old -> new`` pair of ``mapping`` is applied in mapping order as a
    substring replacement, so a later pair also sees the result of an
    earlier one. Values are passed through untouched.
    """
    renamed_state_dict = {}
    for name, tensor in state_dict.items():
        if name.endswith('.inv_freq'):
            continue
        target_name = name
        for old_fragment, new_fragment in mapping.items():
            # ``str.replace`` is a no-op when the fragment is absent, so no
            # explicit membership test is needed.
            target_name = target_name.replace(old_fragment, new_fragment)
        renamed_state_dict[target_name] = tensor
    return renamed_state_dict
|
|
|
|
class SingleLLaVAModelSFT(BaseModel):
    """SFT wrapper that feeds raw image patches to a single language model.

    Image placeholders in ``input_ids`` are expanded into vision tokens by
    ``prepare_inputs_labels_for_multimodal_solo`` and the result is passed
    straight to ``self.llm``. No projector is built by this class
    (``self.projector`` is always ``None``) and ``visual_encoder`` is
    optional.
    """

    def __init__(self,
                 llm,
                 visual_encoder=None,
                 tokenizer=None,
                 freeze_llm=False,
                 freeze_visual_encoder=False,
                 visual_select_layer=-2,
                 pretrained_pth=None,
                 projector_depth=0,
                 llm_lora=None,
                 visual_encoder_lora=None,
                 use_activation_checkpointing=True,
                 max_position_embeddings=None,
                 add_cls_token=False,
                 template=None):
        """Build the LLM (and optional visual encoder) and configure training.

        Args:
            llm: An ``nn.Module`` or a dict config. A dict is first passed
                through ``_dispatch_lm_model_cfg`` and then built with
                ``BUILDER``.
            visual_encoder: Optional ``nn.Module`` or dict config.
            tokenizer: Either a dict config ``{'type': callable, **kwargs}``
                or an already constructed tokenizer. ``None`` leaves
                ``self.tokenizer`` as ``None``.
            freeze_llm: Disable gradients of the LLM.
            freeze_visual_encoder: Disable gradients of the visual encoder
                (when one exists).
            visual_select_layer: Stored on the instance; only read by
                ``to_official_llava``.
            pretrained_pth: Optional checkpoint path, loaded non-strictly.
            projector_depth: Stored on the instance; only read by
                ``to_official_llava``.
            llm_lora: Optional LoRA config (object or buildable dict) for the
                LLM.
            visual_encoder_lora: Optional LoRA config for the visual encoder.
            use_activation_checkpointing: Enable input gradients and
                gradient checkpointing.
            max_position_embeddings: Target context length used to adjust
                RoPE scaling when ``llm`` is given as a dict config.
            add_cls_token: Forwarded as ``add_CLS`` when preparing training
                inputs in ``forward``.
            template: Prompt template mapping; ``'INSTRUCTION'`` and
                ``'STOP_WORDS'`` are read at generation time.
        """
        super().__init__()

        if isinstance(tokenizer, dict):
            # Build from config without mutating the caller's dict (the
            # previous implementation deleted the 'type' key in place).
            tokenizer_type = tokenizer['type']
            tokenizer_kwargs = {
                key: value
                for key, value in tokenizer.items() if key != 'type'
            }
            tokenizer = tokenizer_type(**tokenizer_kwargs)
        # Always define the attribute so a missing tokenizer is reported by
        # this class instead of leaking through ``__getattr__`` to the LLM.
        self.tokenizer = tokenizer

        self.freeze_llm = freeze_llm
        self.freeze_visual_encoder = freeze_visual_encoder
        with LoadWoInit():
            if isinstance(llm, dict):
                llm = self._dispatch_lm_model_cfg(llm,
                                                  max_position_embeddings)

            self.llm = self._build_from_cfg_or_module(llm)

            if visual_encoder is not None:
                self.visual_encoder = self._build_from_cfg_or_module(
                    visual_encoder)
            else:
                self.visual_encoder = None

        self.llm.config.use_cache = False

        self.projector_depth = projector_depth
        # This model variant has no projector module.
        self.projector = None

        if self.freeze_llm:
            self.llm.requires_grad_(False)
        if self.freeze_visual_encoder and self.visual_encoder is not None:
            self.visual_encoder.requires_grad_(False)

        if use_activation_checkpointing:
            # Checkpointing needs inputs that require grad; fall back to a
            # forward hook when the LLM lacks the convenience method.
            if hasattr(self.llm, 'enable_input_require_grads'):
                self.llm.enable_input_require_grads()
            else:
                self.llm.get_input_embeddings().register_forward_hook(
                    make_inputs_require_grad)

            self.gradient_checkpointing_enable()

        self.use_llm_lora = llm_lora is not None
        self.use_visual_encoder_lora = visual_encoder_lora is not None

        if self.use_llm_lora:
            self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
        if self.use_visual_encoder_lora:
            self._prepare_visual_encoder_for_lora(
                visual_encoder_lora, use_activation_checkpointing)

        if pretrained_pth is not None:
            pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
            self.load_state_dict(pretrained_state_dict, strict=False)
            print_log(f'Load pretrained weight from {pretrained_pth}',
                      'current')

        self.visual_select_layer = visual_select_layer

        self._is_init = True

        # ``forward`` moves the whole model to the input device once.
        self.is_first_iter = True

        self.add_cls_token = add_cls_token

        self.template = template

        # Set to True by ``preparing_for_generation``; initialised here so
        # the assertion in ``predict_forward`` can actually fire.
        self.init_prediction_config = False

    def _parse_lora_config(self, lora_config):
        """Build ``lora_config`` with ``BUILDER`` when given as a config."""
        if isinstance(lora_config, (dict, Config, ConfigDict)):
            lora_config = BUILDER.build(lora_config)
        return lora_config

    def _prepare_llm_for_lora(self,
                              lora_config,
                              use_activation_checkpointing=True):
        """Wrap ``self.llm`` in a PEFT model using ``lora_config``.

        When the config names no target modules, all linear layer names
        returned by ``find_all_linear_names`` are used.
        """
        lora_config = self._parse_lora_config(lora_config)
        self.llm = prepare_model_for_kbit_training(
            self.llm, use_activation_checkpointing)
        if lora_config.target_modules is None:
            lora_config.target_modules = find_all_linear_names(self.llm)
        self.llm = get_peft_model(self.llm, lora_config)

    def _prepare_visual_encoder_for_lora(self,
                                         lora_config,
                                         use_activation_checkpointing=True):
        """Wrap ``self.visual_encoder`` in a PEFT model.

        ``use_activation_checkpointing`` is accepted for signature parity
        with ``_prepare_llm_for_lora`` and is not used here.
        """
        lora_config = self._parse_lora_config(lora_config)
        if lora_config.target_modules is None:
            lora_config.target_modules = find_all_linear_names(
                self.visual_encoder)
        self.visual_encoder = get_peft_model(self.visual_encoder, lora_config)

    def gradient_checkpointing_enable(self):
        """Alias of ``activation_checkpointing_enable``."""
        self.activation_checkpointing_enable()

    def activation_checkpointing_enable(self):
        """Enable gradient checkpointing on every sub-module that exists."""
        self.llm.gradient_checkpointing_enable()
        if self.visual_encoder is not None:
            self.visual_encoder.gradient_checkpointing_enable()
        if self.projector is not None:
            self.projector.gradient_checkpointing_enable()

    def gradient_checkpointing_disable(self):
        """Alias of ``activation_checkpointing_disable``."""
        self.activation_checkpointing_disable()

    def activation_checkpointing_disable(self):
        """Disable gradient checkpointing on every sub-module that exists."""
        self.llm.gradient_checkpointing_disable()
        if self.visual_encoder is not None:
            self.visual_encoder.gradient_checkpointing_disable()
        if self.projector is not None:
            self.projector.gradient_checkpointing_disable()

    def init_weights(self):
        """Intentionally a no-op."""
        pass

    def state_dict(self, *args, **kwargs):
        """Return only the weights that are being trained.

        For each of the visual encoder and the LLM: the LoRA weights when
        LoRA is used, otherwise all of its entries unless it is frozen.
        Entries whose key contains ``'projector.'`` are always included.
        """
        state_dict = super().state_dict(*args, **kwargs)
        to_return = OrderedDict()

        # Visual encoder
        if self.use_visual_encoder_lora:
            to_return.update(
                get_peft_model_state_dict(
                    self.visual_encoder, state_dict=state_dict))
        elif not self.freeze_visual_encoder:
            to_return.update({
                k: v
                for k, v in state_dict.items() if 'visual_encoder.' in k
            })

        # LLM
        if self.use_llm_lora:
            to_return.update(
                get_peft_model_state_dict(self.llm, state_dict=state_dict))
        elif not self.freeze_llm:
            to_return.update(
                {k: v
                 for k, v in state_dict.items() if 'llm.' in k})

        # Projector
        to_return.update(
            {k: v
             for k, v in state_dict.items() if 'projector.' in k})
        return to_return

    @staticmethod
    def _prepare_for_long_context_training(cfg, llm_cfg,
                                           max_position_embeddings):
        """Enable linear RoPE scaling when the requested context is longer.

        The effective original context is ``llm_cfg.max_position_embeddings``
        times the existing scaling factor (1 when absent). ``llm_cfg`` is
        attached to ``cfg.config`` and its ``attn_implementation`` is set to
        ``'flash_attention_2'`` unconditionally.

        Returns:
            The ``(cfg, llm_cfg)`` pair, both modified in place.
        """
        orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
        if orig_rope_scaling is None:
            orig_rope_scaling = {'factor': 1}

        orig_rope_scaling_factor = orig_rope_scaling.get('factor', 1)
        orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
        if orig_ctx_len:
            orig_ctx_len *= orig_rope_scaling_factor
            if max_position_embeddings > orig_ctx_len:
                scaling_factor = float(
                    math.ceil(max_position_embeddings / orig_ctx_len))
                llm_cfg.rope_scaling = {
                    'type': 'linear',
                    'factor': scaling_factor
                }

        llm_cfg.attn_implementation = 'flash_attention_2'
        cfg.config = llm_cfg

        return cfg, llm_cfg

    @staticmethod
    def _prepare_for_flash_attn(cfg, llm_cfg):
        """Choose an attention implementation for the LLM build config.

        An explicit ``cfg.attn_implementation`` is respected (only the dtype
        is set for ``'flash_attention_2'``). Otherwise flash-attention 2 or
        SDPA is selected depending on ``SUPPORT_FLASH2`` / ``SUPPORT_FLASH1``
        and the class name of ``llm_cfg``.
        """
        cls_name = type(llm_cfg).__name__
        SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
                             'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
                             'Starcoder2Config', 'Phi3Config')
        SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
                               'MistralConfig', 'MixtralConfig', 'Qwen2Config',
                               'Qwen2MoeConfig', 'Starcoder2Config',
                               'Phi3Config')

        torch_dtype = torch.bfloat16 if (
            torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
            else torch.float16

        if getattr(cfg, 'attn_implementation', None) is not None:
            # The dtype is forced to fp16/bf16 whenever flash-attention 2 is
            # requested explicitly.
            if cfg.attn_implementation == 'flash_attention_2':
                cfg.torch_dtype = torch_dtype
        elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
            cfg.torch_dtype = torch_dtype
            cfg.attn_implementation = 'flash_attention_2'
        elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
            cfg.attn_implementation = 'sdpa'

        return cfg, llm_cfg

    @staticmethod
    def _prepare_for_qlora_zero3(cfg):
        """Align quantization dtypes when DeepSpeed ZeRO-3 is enabled.

        Returns ``cfg`` unchanged unless ZeRO-3 is enabled and ``cfg`` has a
        ``quantization_config``; otherwise sets the torch dtype and the
        4-bit compute/storage dtypes to bf16 (when supported) or fp16.
        """
        if (not is_deepspeed_zero3_enabled()) or (not hasattr(
                cfg, 'quantization_config')):
            return cfg

        torch_dtype = torch.bfloat16 if (
            torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
            else torch.float16

        cfg.torch_dtype = torch_dtype
        quantization_config = cfg.quantization_config
        quantization_config.bnb_4bit_compute_dtype = torch_dtype
        quantization_config.bnb_4bit_quant_storage = torch_dtype

        return cfg

    def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
        """Apply QLoRA/ZeRO-3, attention and long-context tweaks to ``cfg``."""
        cfg = self._prepare_for_qlora_zero3(cfg)
        pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
        llm_cfg = AutoConfig.from_pretrained(
            pretrained_model_name_or_path, trust_remote_code=True)
        cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
        if max_position_embeddings is not None:
            cfg, llm_cfg = self._prepare_for_long_context_training(
                cfg, llm_cfg, max_position_embeddings)
        return cfg

    def _build_from_cfg_or_module(self, cfg_or_mod):
        """Return a module as is, or build a dict config with ``BUILDER``.

        Raises:
            NotImplementedError: If the argument is neither.
        """
        if isinstance(cfg_or_mod, nn.Module):
            return cfg_or_mod
        elif isinstance(cfg_or_mod, dict):
            traverse_dict(cfg_or_mod)
            return BUILDER.build(cfg_or_mod)
        else:
            raise NotImplementedError

    def forward(self, data, data_samples=None, mode='loss'):
        """Expand image placeholders and run the LLM in the given mode.

        Args:
            data: Mapping unpacked as keyword arguments into
                ``prepare_inputs_labels_for_multimodal_solo``; must contain
                ``'input_ids'``.
            data_samples: Passed through to the mode-specific method.
            mode: ``'loss'``, ``'predict'`` or ``'tensor'``.

        Raises:
            NotImplementedError: For any other ``mode``.
        """
        if self.is_first_iter:
            # Move the whole model to the device of the first batch once.
            self.to(data['input_ids'].device)
            self.is_first_iter = False

        data = prepare_inputs_labels_for_multimodal_solo(
            llm=self.llm,
            tokenizer=self.tokenizer,
            **data,
            add_CLS=self.add_cls_token)

        if mode == 'loss':
            loss = self.compute_loss(data, data_samples)
            if torch.isnan(loss["loss"]):
                print("loss nan here")
            return loss
        elif mode == 'predict':
            return self.predict(data, data_samples)
        elif mode == 'tensor':
            return self._forward(data, data_samples)
        else:
            raise NotImplementedError

    def _forward(self, data, data_samples=None):
        """Return the raw LLM outputs for ``data``."""
        outputs = self.llm(**data)

        return outputs

    def predict(self, data, data_samples=None):
        """Return one ``{'logits': ...}`` dict per batch element."""
        outputs = self.llm(**data)
        logits_dict = [{'logits': logits} for logits in outputs.logits]
        return logits_dict

    def compute_loss(self, data, data_samples=None):
        """Return ``{'loss': outputs.loss}`` from the LLM."""
        outputs = self.llm(**data)
        loss_dict = {'loss': outputs.loss}
        return loss_dict

    def __getattr__(self, name: str):
        """Fall back to the wrapped LLM for attributes missing on this model."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            # Without this guard a missing ``llm`` makes ``self.llm`` below
            # re-enter this method forever (RecursionError).
            if name == 'llm':
                raise
            return getattr(self.llm, name)

    def to_hf(self,
              cfg,
              save_dir,
              fp32=False,
              save_pretrained_kwargs=None,
              save_format='xtuner',
              **kwargs):
        """Export the model in ``'xtuner'``, ``'huggingface'`` or ``'official'`` format.

        Raises:
            NotImplementedError: For any other ``save_format``.
        """
        save_pretrained_kwargs = save_pretrained_kwargs or {}
        if save_format == 'xtuner':
            self.to_xtuner_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
        elif save_format == 'huggingface':
            self.to_huggingface_llava(cfg, save_dir, fp32,
                                      save_pretrained_kwargs)
        elif save_format == 'official':
            self.to_official_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
        else:
            raise NotImplementedError

    def to_xtuner_llava(self,
                        cfg,
                        save_dir,
                        fp32=False,
                        save_pretrained_kwargs=None):
        """Save LLM / visual encoder / projector in xtuner's directory layout.

        Components that do not exist on this model (``visual_encoder`` or
        ``projector`` equal to ``None``) are skipped.
        """
        save_pretrained_kwargs = save_pretrained_kwargs or {}

        # LLM
        self.llm.config.use_cache = True
        if not fp32:
            print_log('Convert LLM to float16', 'current')
            self.llm.half()
        if self.use_llm_lora:
            llm_path = osp.join(save_dir, 'llm_adapter')
            print_log(f'Saving LLM adapter to {llm_path}', 'current')
            self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
        elif not self.freeze_llm:
            llm_path = save_dir
            print_log(f'Saving LLM tokenizer to {llm_path}', 'current')
            tokenizer = BUILDER.build(cfg.tokenizer)
            tokenizer.save_pretrained(llm_path, **save_pretrained_kwargs)
            print_log(f'Saving LLM to {llm_path}', 'current')
            self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
        self.llm.config.use_cache = False

        # Visual encoder
        if self.use_visual_encoder_lora:
            visual_encoder_path = osp.join(save_dir, 'visual_encoder_adapter')
            print_log(
                f'Saving visual_encoder adapter to {visual_encoder_path}',
                'current')
            self.visual_encoder.save_pretrained(visual_encoder_path,
                                                **save_pretrained_kwargs)
        elif (not self.freeze_visual_encoder
              and self.visual_encoder is not None):
            visual_encoder_path = osp.join(save_dir, 'visual_encoder')
            print_log(
                'Saving visual_encoder image_processor to '
                f'{visual_encoder_path}', 'current')
            image_processor = BUILDER.build(cfg.image_processor)
            image_processor.save_pretrained(visual_encoder_path,
                                            **save_pretrained_kwargs)
            print_log(f'Saving visual_encoder to {visual_encoder_path}',
                      'current')
            self.visual_encoder.save_pretrained(visual_encoder_path,
                                                **save_pretrained_kwargs)

        # Projector (never built by this class; guard against None)
        if self.projector is not None:
            projector_path = osp.join(save_dir, 'projector')
            print_log(f'Saving projector to {projector_path}', 'current')
            self.projector.save_pretrained(projector_path,
                                           **save_pretrained_kwargs)

    def to_huggingface_llava(self,
                             cfg,
                             save_dir,
                             fp32=False,
                             save_pretrained_kwargs=None):
        """Export as a HuggingFace ``LlavaForConditionalGeneration`` checkpoint.

        Requires a Llama LLM, a CLIP visual encoder and a projector.

        Raises:
            NotImplementedError: If the model has no visual encoder or no
                projector (this class sets ``projector`` to ``None``).
        """
        save_pretrained_kwargs = save_pretrained_kwargs or {}

        LLM_MAPPING = {
            'model': 'language_model.model',
            'lm_head': 'language_model.lm_head',
        }
        VIT_MAPPING = {
            'vision_model': 'vision_tower.vision_model',
        }
        PROJECTOR_MAPPING = {
            'model.0': 'multi_modal_projector.linear_1',
            'model.2': 'multi_modal_projector.linear_2',
        }

        assert getattr(self.llm, 'hf_quantizer', None) is None, \
            'This conversion format does not support quantized LLM.'

        # Fail before merging adapters / casting the LLM rather than with an
        # AttributeError on ``None`` half-way through the export.
        if self.visual_encoder is None or self.projector is None:
            raise NotImplementedError(
                'This conversion format requires both a visual encoder and '
                'a projector.')

        # get state_dict
        llm = self.llm
        if self.use_llm_lora:
            llm = self.llm.merge_and_unload()
        llm.config.use_cache = True
        if not fp32:
            print_log('Convert LLM to float16', 'current')
            llm.half()

        assert isinstance(llm, LlamaForCausalLM), \
            'This conversion format only supports LlamaForCausalLM.'
        llm_state_dict = llm.state_dict()
        llm_state_dict = convert_state_dict_to_hf(llm_state_dict, LLM_MAPPING)

        need_visual_encoder = (not self.freeze_visual_encoder
                               or self.use_visual_encoder_lora)
        visual_encoder = self.visual_encoder
        if self.use_visual_encoder_lora:
            visual_encoder = self.visual_encoder.merge_and_unload()
        assert isinstance(visual_encoder, CLIPVisionModel), \
            'This conversion format only supports CLIPVisionModel.'
        if need_visual_encoder:
            visual_encoder_state_dict = visual_encoder.state_dict()
            visual_encoder_state_dict = convert_state_dict_to_hf(
                visual_encoder_state_dict, VIT_MAPPING)
        else:
            visual_encoder_state_dict = {}

        projector_state_dict = self.projector.state_dict()
        projector_state_dict = convert_state_dict_to_hf(
            projector_state_dict, PROJECTOR_MAPPING)

        state_dict = {
            **projector_state_dict,
            **llm_state_dict,
            **visual_encoder_state_dict
        }

        # init model
        text_config = llm.config
        vision_config = visual_encoder.config
        config = LlavaConfig(
            text_config=text_config,
            vision_config=vision_config,
            attn_implementation='eager')

        with init_empty_weights():
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    'ignore', message='.*non-meta.*', category=UserWarning)
                model = LlavaForConditionalGeneration(config)
        model.load_state_dict(state_dict, strict=True, assign=True)

        # processor
        cfg.tokenizer.type = LlamaTokenizerFast.from_pretrained
        tokenizer = BUILDER.build(cfg.tokenizer)

        tokenizer.add_tokens(
            AddedToken(DEFAULT_IMAGE_TOKEN, special=True, normalized=False),
            special_tokens=True)
        tokenizer.add_special_tokens({'pad_token': '<pad>'})

        image_processor = BUILDER.build(cfg.image_processor)
        assert isinstance(image_processor, CLIPImageProcessor), \
            'This conversion format only supports CLIPImageProcessor.'

        processor = LlavaProcessor(
            tokenizer=tokenizer, image_processor=image_processor)

        # Second argument of ``resize_token_embeddings`` below.
        pad_shape = 64

        # New embedding rows are sampled from a Gaussian fitted to the
        # existing rows, with the covariance scaled by 1e-5.
        pre_expansion_embeddings = \
            model.language_model.model.embed_tokens.weight.data
        mu = torch.mean(pre_expansion_embeddings, dim=0).float()
        n = pre_expansion_embeddings.size()[0]
        sigma = ((pre_expansion_embeddings - mu).T
                 @ (pre_expansion_embeddings - mu)) / n
        dist = torch.distributions.multivariate_normal.MultivariateNormal(
            mu, covariance_matrix=1e-5 * sigma)

        # check pad and image token
        ori_vocab_size = config.text_config.vocab_size
        tokenizer_vocab_size = tokenizer.encode('<pad>')[-1]
        added_token = tokenizer_vocab_size - ori_vocab_size

        if added_token > 0:
            model.resize_token_embeddings(ori_vocab_size + added_token,
                                          pad_shape)
            model.language_model.model.embed_tokens.weight.data[
                ori_vocab_size:] = torch.stack(
                    tuple(
                        dist.sample()
                        for _ in range(model.language_model.model.embed_tokens.
                                       weight.data[ori_vocab_size:].shape[0])),
                    dim=0,
                )
            model.language_model.lm_head.weight.data[
                ori_vocab_size:] = torch.stack(
                    tuple(dist.sample()
                          for _ in range(model.language_model.lm_head.weight.
                                         data[ori_vocab_size:].shape[0])),
                    dim=0,
                )
        model.config.image_token_index = tokenizer.encode(
            DEFAULT_IMAGE_TOKEN)[-1]
        model.config.pad_token_id = tokenizer.encode('<pad>')[-1]

        # save
        print_log(f'Saving to {save_dir}', 'current')
        model.save_pretrained(save_dir, **save_pretrained_kwargs)
        processor.save_pretrained(save_dir, **save_pretrained_kwargs)

    def to_official_llava(self,
                          cfg,
                          save_dir,
                          fp32=False,
                          save_pretrained_kwargs=None):
        """Export as an official-LLaVA ``LlavaLlamaForCausalLM`` checkpoint.

        Requires the ``llava`` package, a Llama LLM, a CLIP visual encoder
        and a projector.

        Raises:
            ImportError: If ``llava`` is not installed.
            NotImplementedError: If the model has no visual encoder or no
                projector (this class sets ``projector`` to ``None``).
        """
        save_pretrained_kwargs = save_pretrained_kwargs or {}

        VIT_MAPPING = {
            'vision_model': 'model.vision_tower.vision_tower.vision_model',
        }
        PROJECTOR_MAPPING = {
            'model.0': 'model.mm_projector.0',
            'model.2': 'model.mm_projector.2',
        }

        try:
            from llava.model import LlavaConfig, LlavaLlamaForCausalLM
        except ImportError as exc:
            raise ImportError(
                'Please install llava with '
                '`pip install git+https://github.com/haotian-liu/LLaVA.git '
                '--no-deps`.') from exc

        assert getattr(self.llm, 'hf_quantizer', None) is None, \
            'This conversion format does not support quantized LLM.'

        # Fail before merging adapters / casting the LLM rather than with an
        # AttributeError on ``None`` half-way through the export.
        if self.visual_encoder is None or self.projector is None:
            raise NotImplementedError(
                'This conversion format requires both a visual encoder and '
                'a projector.')

        # get state_dict
        llm = self.llm
        if self.use_llm_lora:
            llm = self.llm.merge_and_unload()
        llm.config.use_cache = True
        if not fp32:
            print_log('Convert LLM to float16', 'current')
            llm.half()

        assert isinstance(llm, LlamaForCausalLM), \
            'This conversion format only supports LlamaForCausalLM.'
        llm_state_dict = llm.state_dict()

        need_visual_encoder = (not self.freeze_visual_encoder
                               or self.use_visual_encoder_lora)
        visual_encoder = self.visual_encoder
        if self.use_visual_encoder_lora:
            visual_encoder = self.visual_encoder.merge_and_unload()
        assert isinstance(visual_encoder, CLIPVisionModel), \
            'This conversion format only supports CLIPVisionModel.'
        if need_visual_encoder:
            visual_encoder_state_dict = visual_encoder.state_dict()
            visual_encoder_state_dict = convert_state_dict_to_hf(
                visual_encoder_state_dict, VIT_MAPPING)
        else:
            visual_encoder_state_dict = {}

        projector_state_dict = self.projector.state_dict()
        projector_state_dict = convert_state_dict_to_hf(
            projector_state_dict, PROJECTOR_MAPPING)

        state_dict = {
            **projector_state_dict,
            **llm_state_dict,
            **visual_encoder_state_dict
        }

        # init model
        tokenizer = BUILDER.build(cfg.tokenizer)
        image_processor = BUILDER.build(cfg.image_processor)
        assert isinstance(image_processor, CLIPImageProcessor), \
            'This conversion format only supports CLIPImageProcessor.'

        llava_config_dict = llm.config.__dict__.copy()
        llava_config_dict.update(
            dict(
                image_aspect_ratio='pad',
                mm_hidden_size=visual_encoder.config.hidden_size,
                mm_projector_type=f'mlp{self.projector_depth}x_gelu',
                mm_use_im_patch_token=False,
                mm_use_im_start_end=False,
                mm_vision_select_feature='patch',
                mm_vision_select_layer=self.visual_select_layer,
                mm_vision_tower=visual_encoder.config.name_or_path,
                unfreeze_mm_vision_tower=need_visual_encoder,
                model_type='llava',
                use_cache=True,
                use_mm_proj=True))

        llava_config = LlavaConfig(**llava_config_dict)

        with init_empty_weights():
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    'ignore', message='.*non-meta.*', category=UserWarning)
                model = LlavaLlamaForCausalLM(llava_config)

        model.load_state_dict(state_dict, strict=True, assign=True)

        # save
        print_log(f'Saving to {save_dir}', 'current')

        model.save_pretrained(save_dir, **save_pretrained_kwargs)
        image_processor.save_pretrained(save_dir, **save_pretrained_kwargs)
        tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)

    def preparing_for_generation(self, metainfo):
        """Set up stop criteria and the generation config for ``predict_forward``.

        Args:
            metainfo: Mapping; its optional ``'generation_kwargs'`` entry
                overrides the default generation settings.
        """
        assert getattr(self, 'tokenizer', None) is not None, \
            "The Model does not have the tokenizer!!!"
        self.bot_name = 'BOT'

        stop_words = []
        stop_words += self.template.get('STOP_WORDS', [])
        self.stop_criteria = get_stop_criteria(
            tokenizer=self.tokenizer, stop_words=stop_words)

        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            # Fall back to EOS when the tokenizer defines no pad token.
            pad_token_id = self.tokenizer.eos_token_id
        default_generation_kwargs = dict(
            max_new_tokens=2048,
            do_sample=False,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=pad_token_id,
        )
        default_generation_kwargs.update(metainfo.get('generation_kwargs', {}))
        self.gen_config = GenerationConfig(**default_generation_kwargs)
        self.init_prediction_config = True

    def predict_forward(self, pixel_values, text_prompts, **kwargs):
        """Generate a text answer for one image and one prompt.

        Args:
            pixel_values: Image tensor without batch dimension; the last two
                dims are read as height and width (presumably ``(C, H, W)``
                - confirm against the caller).
            text_prompts: Prompt string; any ``'<image>\\n'`` marker is
                removed because the image token is prepended here.
            **kwargs: Ignored.

        Returns:
            ``{'prediction': str}`` with the decoded continuation.
        """
        text_prompts = text_prompts.replace('<image>\n', '').strip()

        assert self.init_prediction_config, \
            "Please set prediction configs using self.preparing_for_generation()"

        input_text = self.template['INSTRUCTION'].format(
            input=text_prompts, round=1, bot_name=self.bot_name)

        cur_encode = self.tokenizer.encode(input_text)

        # A single image placeholder precedes the encoded prompt.
        ids = [IMAGE_TOKEN_INDEX]
        ids.extend(cur_encode)
        ids = torch.tensor(ids).cuda().unsqueeze(0)

        pixel_values = pixel_values.cuda().unsqueeze(0)

        # Cap the longer side at 1024 and round both sides up to a multiple
        # of 32 so the image splits into whole 32x32 patches.
        h, w = pixel_values.shape[-2:]
        if max(h, w) > 1024:
            if h > w:
                h_new = 1024
                w_new = pad_32(int(w * h_new / h))
            else:
                w_new = 1024
                h_new = pad_32(int(h * w_new / w))
        else:
            h_new = pad_32(h)
            w_new = pad_32(w)
        dtype = pixel_values.dtype
        # Interpolate in float32, then restore the incoming dtype.
        pixel_values = F.interpolate(
            pixel_values.to(torch.float32),
            size=(h_new, w_new),
            mode='bilinear',
            align_corners=False).to(dtype)

        # NOTE(review): unlike ``forward``, ``add_CLS`` is not passed here,
        # so no '<|vis_cls|>' token is appended at inference - confirm this
        # is intended.
        mm_inputs = prepare_inputs_labels_for_multimodal_solo(
            llm=self.llm,
            tokenizer=self.tokenizer,
            input_ids=ids,
            pixel_values=pixel_values)

        if mm_inputs.get('input_ids') is not None:
            inp_length = mm_inputs['input_ids'].shape[1]
        else:
            inp_length = 0

        generate_output = self.llm.generate(
            **mm_inputs,
            generation_config=self.gen_config,
            streamer=None,
            bos_token_id=self.tokenizer.bos_token_id,
            stopping_criteria=self.stop_criteria,
            output_hidden_states=False,
            return_dict_in_generate=True)

        # Strip the prompt part before decoding.
        predict = self.tokenizer.decode(
            generate_output.sequences[0][inp_length:],
            skip_special_tokens=True).strip()
        print(predict)
        return {'prediction': predict}
|
|
|
|
def _split_image_into_patches(image, patch_size):
    """Cut ``image`` into non-overlapping square patches, row-major.

    ``image`` is unfolded along dims 1 and 2, so it is expected to be a 3-D
    tensor whose dim 0 is kept inside each patch (presumably ``(C, H, W)``
    - confirm against the caller).

    Returns:
        ``(patches, n_rows, n_cols)`` where ``patches`` has shape
        ``(n_rows * n_cols, -1)``.
    """
    tiles = image.unfold(1, patch_size, patch_size) \
        .unfold(2, patch_size, patch_size)
    tiles = tiles.permute(1, 2, 0, 3, 4).contiguous()
    n_rows, n_cols = tiles.shape[:2]
    return tiles.view(n_rows * n_cols, -1), n_rows, n_cols


def _vision_token_ids(tokenizer, n_rows, n_cols, add_CLS):
    """Return the token ids that stand in for one image in the text stream.

    Layout: ``<vision>``, then one ``<vpatch>`` per patch with a
    ``<vrow_sep>`` before every row except the first, then ``</vision>``
    and, when ``add_CLS`` is set, ``<|vis_cls|>``.
    """
    img_tokens = ["<vision>"]
    for row_idx in range(n_rows):
        if row_idx != 0:
            img_tokens.append("<vrow_sep>")
        img_tokens.extend(["<vpatch>"] * n_cols)
    img_tokens.append("</vision>")
    if add_CLS:
        img_tokens.append("<|vis_cls|>")
    # Build an integer tensor directly; going through ``torch.Tensor``
    # (float32) would round ids that are not exactly representable.
    return torch.tensor(
        tokenizer.convert_tokens_to_ids(img_tokens), dtype=torch.long)


def prepare_inputs_labels_for_multimodal_solo(
        llm: PreTrainedModel,
        tokenizer=None,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        add_CLS: bool = False,
):
    """Replace image placeholders with vision tokens and collect patches.

    Every ``IMAGE_TOKEN_INDEX`` entry of ``input_ids`` is replaced by the
    vision token sequence of the next image. Images are cut into 32x32
    patches, and the batch is right-padded to a common length.

    Args:
        llm: Unused; kept for interface compatibility.
        tokenizer: Must know ``<vision>``, ``<vrow_sep>``, ``<vpatch>``,
            ``</vision>`` and (with ``add_CLS``) ``<|vis_cls|>``.
        input_ids: Batch of token ids, indexed as ``(batch, seq)``.
        position_ids: Optional; regenerated per sample when given.
        attention_mask: Optional; used to strip padding before expansion.
        past_key_values: Passed through unchanged.
        labels: Optional; vision positions are filled with ``IGNORE_INDEX``.
        pixel_values: ``None`` (text-only), a tensor whose first dim is 1
            (a single image), or an iterable of per-image tensors.
        add_CLS: Append ``<|vis_cls|>`` after each image's closing token.

    Returns:
        When ``pixel_values`` is ``None``: the inputs unchanged, plus
        ``'inputs_embeds': None``. Otherwise a dict with ``input_ids``,
        ``position_ids``, ``attention_mask``, ``past_key_values``,
        ``labels``, ``vision_patch_indices`` (``-1`` except at ``<vpatch>``
        positions, which hold running patch indices) and ``vision_patches``.
        ``position_ids`` / ``attention_mask`` / ``labels`` are ``None`` in
        the result when they were ``None`` on input.
    """
    ori_input_ids = input_ids

    # Text-only batch: nothing to expand.
    if pixel_values is None:
        return {
            'input_ids': input_ids,
            'position_ids': position_ids,
            'attention_mask': attention_mask,
            'past_key_values': past_key_values,
            'inputs_embeds': None,
            'labels': labels
        }

    _labels = labels
    _position_ids = position_ids
    _attention_mask = attention_mask

    patch_size = 32
    NON_VISION_TOKEN = -1

    if isinstance(pixel_values, torch.Tensor):
        # A tensor input carries exactly one image.
        assert pixel_values.shape[0] == 1
        images = [pixel_values[0]]
    else:
        images = pixel_values

    vision_patches = []
    visual_tokens = []
    for image in images:
        patches, n_rows, n_cols = _split_image_into_patches(image, patch_size)
        vision_patches.append(patches)
        visual_tokens.append(
            _vision_token_ids(tokenizer, n_rows, n_cols, add_CLS))
    vision_patches = torch.cat(vision_patches, dim=0)

    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
    else:
        attention_mask = attention_mask.bool()
    if position_ids is None:
        position_ids = torch.arange(
            0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
    if labels is None:
        labels = torch.full_like(input_ids, IGNORE_INDEX)

    # Remove the padding using attention_mask.
    input_ids = [
        cur_input_ids[cur_attention_mask]
        for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
    ]
    labels = [
        cur_labels[cur_attention_mask]
        for cur_labels, cur_attention_mask in zip(labels, attention_mask)
    ]

    new_inputs_ids = []
    new_labels = []
    cur_image_idx = 0

    for batch_idx, cur_input_ids in enumerate(input_ids):
        cur_labels = labels[batch_idx]
        need_replace = cur_input_ids == IMAGE_TOKEN_INDEX
        num_replace = int(need_replace.sum())
        if num_replace == 0:
            # A sample without placeholder still advances the image cursor
            # (preserved from the original implementation).
            new_inputs_ids.append(cur_input_ids)
            new_labels.append(cur_labels)
            cur_image_idx += 1
            continue

        # Boundaries of the text segments between image placeholders.
        image_token_indices = [-1] + torch.where(
            need_replace)[0].tolist() + [cur_input_ids.shape[0]]

        cur_new_inputs_ids = []
        cur_new_labels = []
        for i in range(num_replace + 1):
            start = image_token_indices[i] + 1
            end = image_token_indices[i + 1]
            cur_new_inputs_ids.append(cur_input_ids[start:end])
            cur_new_labels.append(cur_labels[start:end])
            if i < num_replace:
                cur_vision_tokens = visual_tokens[cur_image_idx].to(
                    ori_input_ids)
                cur_new_inputs_ids.append(cur_vision_tokens)
                # Vision positions never contribute to the loss.
                cur_new_labels.append(
                    torch.full((cur_vision_tokens.shape[0], ),
                               IGNORE_INDEX,
                               device=cur_labels.device,
                               dtype=cur_labels.dtype))
                cur_image_idx += 1

        new_inputs_ids.append(torch.cat(cur_new_inputs_ids))
        new_labels.append(torch.cat(cur_new_labels))

    # Right-pad every sample to the longest one.
    max_len = max(x.shape[0] for x in new_inputs_ids)
    batch_size = len(new_inputs_ids)

    new_inputs_ids_padded = []
    new_labels_padded = torch.full((batch_size, max_len),
                                   IGNORE_INDEX,
                                   dtype=new_labels[0].dtype,
                                   device=new_labels[0].device)
    attention_mask = torch.zeros((batch_size, max_len),
                                 dtype=attention_mask.dtype,
                                 device=attention_mask.device)
    position_ids = torch.zeros((batch_size, max_len),
                               dtype=position_ids.dtype,
                               device=position_ids.device)

    for i, (cur_new_id,
            cur_new_labels) in enumerate(zip(new_inputs_ids, new_labels)):
        cur_len = cur_new_id.shape[0]
        new_inputs_ids_padded.append(
            torch.cat((cur_new_id,
                       torch.zeros((max_len - cur_len, ),
                                   dtype=cur_new_id.dtype,
                                   device=cur_new_id.device)),
                      dim=0))
        if cur_len > 0:
            new_labels_padded[i, :cur_len] = cur_new_labels
            attention_mask[i, :cur_len] = True
            position_ids[i, :cur_len] = torch.arange(
                0,
                cur_len,
                dtype=position_ids.dtype,
                device=position_ids.device)

    new_inputs_ids = torch.stack(new_inputs_ids_padded, dim=0)

    # Only return the optional tensors the caller originally supplied.
    new_labels = None if _labels is None else new_labels_padded

    if _attention_mask is None:
        attention_mask = None
    else:
        attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

    if _position_ids is None:
        position_ids = None

    # Patch indices are derived from the final token layout: every
    # '<vpatch>' position receives the next running index, everything else
    # the sentinel.
    vpatch_id = tokenizer.encode("<vpatch>", add_special_tokens=False)[0]
    is_vpatch = new_inputs_ids == vpatch_id
    num_vpatch = int(is_vpatch.sum().item())
    assert vision_patches.size(0) == num_vpatch, \
        f"number of vision patches is not the same as indicated in indices: {vision_patches.size(0)} vs {num_vpatch}."
    vpatch_indices = torch.full_like(new_inputs_ids, NON_VISION_TOKEN)
    vpatch_indices[is_vpatch] = torch.arange(
        num_vpatch, dtype=vpatch_indices.dtype, device=vpatch_indices.device)

    return {
        'input_ids': new_inputs_ids,
        'position_ids': position_ids,
        'attention_mask': attention_mask,
        'past_key_values': past_key_values,
        'labels': new_labels,
        'vision_patch_indices': vpatch_indices,
        'vision_patches': vision_patches,
    }
|
|
|
|
def pad_32(val):
    """Round ``val`` up to the nearest multiple of 32.

    A value that is already a multiple of 32 is returned unchanged.
    """
    whole_blocks, remainder = divmod(val, 32)
    if not remainder:
        return val
    return (whole_blocks + 1) * 32
|
|