from typing import Any, List, Optional, Tuple, Union

import torch
from peft import LoraConfig, get_peft_model
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import (AutoModel, GenerationConfig, Qwen3ForCausalLM,
                          StoppingCriteria, StoppingCriteriaList)
from transformers.activations import ACT2FN
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, logging

from .configuration_dinov3_vit import DINOv3ViTConfig
from .configuration_vectorllm import ProjectorConfig, VectorLLMConfig
from .image_processing_vectorllm import VectorLLMImageProcessor
from .modeling_dinov3_vit import DINOv3ViTModel
from .processing_vectorllm import VectorLLMProcessor

logger = logging.get_logger(__name__)

class ProjectorModel(PreTrainedModel):
    """MLP projector that maps vision-encoder features into the LLM embedding space."""

    _auto_class = "AutoModel"
    config_class = ProjectorConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True

    def __init__(self, config: ProjectorConfig) -> None:
        super().__init__(config)
        self.gradient_checkpointing = False

        # The first layer maps visual_hidden_size -> llm_hidden_size; each of the remaining
        # `depth - 1` layers is an activation followed by an llm_hidden_size -> llm_hidden_size linear.
        modules = [
            nn.Linear(
                config.visual_hidden_size, config.llm_hidden_size, bias=config.bias
            )
        ]
        for _ in range(1, config.depth):
            modules.append(ACT2FN[config.hidden_act])
            modules.append(
                nn.Linear(
                    config.llm_hidden_size, config.llm_hidden_size, bias=config.bias
                )
            )
        self.model = nn.Sequential(*modules)

    def enable_input_require_grads(self):
        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        self.model.register_forward_hook(make_inputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, ProjectorModel):
            module.gradient_checkpointing = value

    def forward(self, x):
        if self.gradient_checkpointing and self.training:
            layer_outputs = torch.utils.checkpoint.checkpoint(self.model, x)
        else:
            layer_outputs = self.model(x)
        return layer_outputs
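# Shape sketch for the projector (illustrative only; the config values below are assumptions,
# not the shipped configuration):
#   cfg = ProjectorConfig(visual_hidden_size=1024, llm_hidden_size=2048,
#                         depth=2, hidden_act="gelu", bias=True)
#   projector = ProjectorModel(cfg)
#   feats = torch.randn(2, 256, 1024)            # (batch, num_patches, visual_hidden_size)
#   assert projector(feats).shape == (2, 256, 2048)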
class StopWordStoppingCriteria(StoppingCriteria):
    """Stops generation once the decoded text ends with the given stop word."""

    def __init__(self, tokenizer, stop_word):
        self.tokenizer = tokenizer
        self.stop_word = stop_word
        self.length = len(self.stop_word)

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        cur_text = self.tokenizer.decode(input_ids[0])
        cur_text = cur_text.replace('\r', '').replace('\n', '')
        return cur_text[-self.length:] == self.stop_word


def get_stop_criteria(tokenizer, stop_words=()):
    """Build a StoppingCriteriaList with one StopWordStoppingCriteria per stop word."""
    stop_criteria = StoppingCriteriaList()
    for word in stop_words:
        stop_criteria.append(StopWordStoppingCriteria(tokenizer, word))
    return stop_criteria
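# Illustrative usage (assumes a tokenizer is already loaded; the stop word below is an example):
#   stop_criteria = get_stop_criteria(tokenizer, stop_words=['<|im_end|>'])
#   model.generate(..., stopping_criteria=stop_criteria)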
class VectorLLMWrapModel(PreTrainedModel):
    """Wrapper exposing VectorLLMModel through the standard forward/generate interface."""

    config_class = VectorLLMConfig
    main_input_name = 'pixel_values'
    base_model_prefix = 'language_model'
    _no_split_modules = ['DINOv3ViTModel', 'Qwen3DecoderLayer']
    _supports_flash_attn_2 = True
    supports_gradient_checkpointing = True

    def __init__(
        self, config: VectorLLMConfig, vision_model=None, language_model=None,
        projector=None, pos_embeds=None, use_flash_attn=True, vectorllm_model=None,
    ):
        super().__init__(config)
        config.vision_config.use_flash_attn = bool(use_flash_attn)
        config.llm_config._attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'

        self.vit_hidden_size = config.vision_hidden_size
        self.llm_hidden_size = config.hidden_size

        self.pixel_idx = config.pixel_idx
        self.num_cls_register_tokens = config.num_cls_register_tokens

        if vectorllm_model is None:
            self.model = VectorLLMModel(
                config=config, vision_model=vision_model,
                language_model=language_model, projector=projector,
                pos_embeds=pos_embeds, use_flash_attn=use_flash_attn
            )
        else:
            self.model = vectorllm_model
    @property
    def lm_head(self):
        return self.model.get_output_embeddings()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def get_output_embeddings(self):
        return self.model.get_output_embeddings()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values=None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
        # Delegate to the wrapped VectorLLMModel.
        return self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            pixel_values=pixel_values,
            labels=labels,
        )

    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        visual_features: Optional[torch.FloatTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        # Delegate to the wrapped VectorLLMModel.
        return self.model.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            visual_features=visual_features,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **generate_kwargs,
        )
class VectorLLMModel(PreTrainedModel):
    """Vision-language model that fuses DINOv3 patch features into a Qwen3 causal LM."""

    config_class = VectorLLMConfig
    main_input_name = 'pixel_values'
    base_model_prefix = 'language_model'
    _no_split_modules = ['DINOv3ViTModel', 'Qwen3DecoderLayer']
    _supports_flash_attn_2 = True
    supports_gradient_checkpointing = True

    def __init__(
        self, config: VectorLLMConfig, vision_model=None, language_model=None,
        projector=None, pos_embeds=None, use_flash_attn=True
    ):
        super().__init__(config)
        config.vision_config.use_flash_attn = bool(use_flash_attn)
        config.llm_config._attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'

        # Build the vision tower, language model and projector from the config
        # unless pre-built modules are supplied.
        if vision_model is not None:
            self.vision_model = vision_model
        else:
            self.vision_model = DINOv3ViTModel(config.vision_config)
        if language_model is not None:
            self.language_model = language_model
        else:
            self.language_model = Qwen3ForCausalLM(config.llm_config)

        self.vit_hidden_size = config.vision_hidden_size
        self.llm_hidden_size = config.hidden_size

        if projector is not None:
            self.projector = projector
        else:
            self.projector = ProjectorModel(config.projector_config)

        # Learned positional embeddings, one per 16x16-pixel patch of the regression-sized image.
        w, h = (config.regression_size[0] // 16, config.regression_size[1] // 16)
        n_pos = w * h

        if pos_embeds is not None:
            self.visual_pos_embeddings = pos_embeds
        else:
            self.visual_pos_embeddings = nn.Embedding(n_pos, self.vit_hidden_size)

        self.pixel_idx = config.pixel_idx
        self.num_cls_register_tokens = config.num_cls_register_tokens
    def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        lora_config = LoraConfig(
            r=r,
            target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'],
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )
        self.vision_model = get_peft_model(self.vision_model, lora_config)
        self.vision_model.print_trainable_parameters()

    def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        target_modules = ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
                          'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj']
        lora_config = LoraConfig(
            r=r,
            target_modules=target_modules,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            task_type='CAUSAL_LM'
        )
        self.language_model = get_peft_model(self.language_model, lora_config)
        self.language_model.enable_input_require_grads()
        self.language_model.print_trainable_parameters()
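    # Illustrative usage (assumes `model` is an instantiated VectorLLMModel):
    #   model.wrap_llm_lora(r=128, lora_alpha=256)        # LoRA adapters on the Qwen3 LLM
    #   model.wrap_backbone_lora(r=128, lora_alpha=256)   # LoRA adapters on the DINOv3 tower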
    def extract_feature(self, pixel_values):
        # Drop the CLS/register tokens so only patch features remain.
        features = self.vision_model(pixel_values).last_hidden_state[:, self.num_cls_register_tokens:, :]
        features.requires_grad_(True)

        # Add the learned positional embeddings, broadcast over the batch.
        pos_embed = self.visual_pos_embeddings.weight.unsqueeze(0)
        pos_embed = pos_embed.repeat(features.shape[0], 1, 1)
        features = features + pos_embed

        return features

    @property
    def lm_head(self):
        return self.language_model.get_output_embeddings()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values=None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
        # Normalize pixel_values to a single 4-D batch (B, C, H, W) in the vision dtype.
        if type(pixel_values) is list or pixel_values.ndim == 5:
            if type(pixel_values) is list:
                pixel_values = [
                    x.unsqueeze(0) if x.ndim == 3 else x for x in pixel_values
                ]
            concat_images = torch.cat(
                [image.to(self.vision_model.dtype) for image in pixel_values], dim=0)
        elif pixel_values.ndim == 4:
            concat_images = pixel_values.to(self.vision_model.dtype)
        else:
            raise NotImplementedError()

        # All-zero images are treated as padding; their features are dropped later.
        image_flags = torch.sum(concat_images, dim=(1, 2, 3)) != 0
        image_flags = image_flags.long()

        use_cache = use_cache if use_cache is not None else False

        outputs = self._llm_forward(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            image_flags=image_flags,
            pixel_values=concat_images,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            labels=labels,
            use_cache=use_cache,
        )

        return outputs
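    # Illustrative training-step sketch (assumes a batch produced by a VectorLLMProcessor-style
    # collator; the key names here are assumptions):
    #   out = model(input_ids=batch['input_ids'],
    #               attention_mask=batch['attention_mask'],
    #               pixel_values=batch['pixel_values'],
    #               labels=batch['labels'])
    #   out.loss.backward()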
    def _llm_forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        image_flags: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None \
            else self.config.use_return_dict

        image_flags = image_flags.squeeze(-1)

        input_embeds = self.language_model.get_input_embeddings()(
            input_ids).clone()

        vit_embeds = self.extract_feature(pixel_values)
        vit_embeds = vit_embeds.to(input_embeds.dtype)

        # Keep features only for real (non-padding) images.
        vit_embeds = vit_embeds[image_flags == 1]

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)
        vit_embeds = vit_embeds.to(input_embeds.dtype)

        # Scatter the visual embeddings into the positions of the image placeholder tokens.
        input_ids = input_ids.reshape(B * N)
        selected = (input_ids == self.pixel_idx)

        try:
            input_embeds[selected] = vit_embeds.reshape(-1, C)
        except Exception as e:
            # Shape mismatch between placeholder tokens and visual embeddings; repeat the
            # visual embeddings as a fallback so the forward pass can still proceed.
            vit_embeds = vit_embeds.reshape(-1, C)
            logger.warning(f'warning: {e}, input_embeds[selected].shape='
                           f'{input_embeds[selected].shape}, '
                           f'vit_embeds.shape={vit_embeds.shape}')
            n_token = selected.sum()
            if n_token > len(vit_embeds):
                logger.warning(f"Wrong !!! {n_token} image tokens in text but only {len(vit_embeds)} vit embeds !!!")
                expand_ratio = n_token // len(vit_embeds) + 1
                vit_embeds = torch.cat([vit_embeds] * expand_ratio, dim=0)

            input_embeds[selected] = vit_embeds[:n_token]

        input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens.
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(
                -1, self.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Move labels to the logits device to support model parallelism.
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        visual_features: Optional[torch.FloatTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        device = self.device

        if pixel_values is not None:
            if visual_features is not None:
                # Use pre-computed visual features when provided.
                vit_embeds = visual_features
            else:
                if type(pixel_values) is list or pixel_values.ndim == 5:
                    if type(pixel_values) is list:
                        pixel_values = [
                            x.unsqueeze(0) if x.ndim == 3 else x for x in pixel_values
                        ]
                    pixel_values = torch.cat(
                        [image.to(self.vision_model.dtype) for image in pixel_values], dim=0)

                vit_embeds = self.extract_feature(pixel_values.to(device))
                # All-zero images are padding; drop their features.
                image_flags = torch.sum(pixel_values, dim=(1, 2, 3)) != 0
                image_flags = image_flags.long()
                vit_embeds = vit_embeds[image_flags == 1]

            # Scatter the visual embeddings into the image placeholder positions.
            input_embeds = self.language_model.get_input_embeddings()(input_ids.to(device))
            vit_embeds = vit_embeds.to(input_embeds.dtype)
            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)

            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.pixel_idx)
            assert selected.sum() != 0
            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
            input_embeds = input_embeds.reshape(B, N, C)
        else:
            input_embeds = self.language_model.get_input_embeddings()(input_ids)

        outputs = self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask.to(device) if attention_mask is not None else None,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            **generate_kwargs,
        )

        return outputs
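# Illustrative end-to-end sketch (the path, prompt handling and key names are assumptions;
# actual preprocessing lives in VectorLLMProcessor / VectorLLMImageProcessor):
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("path/to/vectorllm", trust_remote_code=True)
#   model = VectorLLMWrapModel.from_pretrained("path/to/vectorllm", torch_dtype=torch.bfloat16).cuda().eval()
#   inputs = ...  # input_ids containing `config.pixel_idx` placeholder tokens, plus pixel_values
#   stop_criteria = get_stop_criteria(tokenizer, stop_words=['<|im_end|>'])
#   output_ids = model.generate(pixel_values=inputs['pixel_values'].cuda(),
#                               input_ids=inputs['input_ids'].cuda(),
#                               attention_mask=inputs['attention_mask'].cuda(),
#                               stopping_criteria=stop_criteria,
#                               max_new_tokens=512)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))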