# vectorllm-hf / modeling_vectorllm.py
from typing import Any, List, Optional, Tuple, Union

import torch
from peft import LoraConfig, get_peft_model
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import (AutoModel, GenerationConfig, Qwen3ForCausalLM,
                          StoppingCriteria, StoppingCriteriaList)
from transformers.activations import ACT2FN
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, logging

from .configuration_dinov3_vit import DINOv3ViTConfig
from .configuration_vectorllm import ProjectorConfig, VectorLLMConfig
from .image_processing_vectorllm import VectorLLMImageProcessor
from .modeling_dinov3_vit import DINOv3ViTModel
from .processing_vectorllm import VectorLLMProcessor
logger = logging.get_logger(__name__)
class ProjectorModel(PreTrainedModel):
_auto_class = "AutoModel"
config_class = ProjectorConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
def __init__(self, config: ProjectorConfig) -> None:
super().__init__(config)
self.gradient_checkpointing = False
modules = [
nn.Linear(
config.visual_hidden_size, config.llm_hidden_size, bias=config.bias
)
]
for _ in range(1, config.depth):
modules.append(ACT2FN[config.hidden_act])
modules.append(
nn.Linear(
config.llm_hidden_size, config.llm_hidden_size, bias=config.bias
)
)
self.model = nn.Sequential(*modules)
def enable_input_require_grads(self):
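        # Forces the projector's output to require grad so gradients can still
        # flow to upstream trainable parameters (e.g. LoRA adapters) when the
        # base weights are frozen and gradient checkpointing is enabled.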
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
self.model.register_forward_hook(make_inputs_require_grad)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, ProjectorModel):
module.gradient_checkpointing = value
def forward(self, x):
if self.gradient_checkpointing and self.training:
layer_outputs = torch.utils.checkpoint.checkpoint(self.model, x)
else:
layer_outputs = self.model(x)
return layer_outputs
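# Illustrative shape sketch (not part of the original file): the projector maps
# ViT features of width `visual_hidden_size` to the LLM hidden size with `depth`
# linear layers. The config values below are hypothetical, for demonstration only.
#
#   cfg = ProjectorConfig(visual_hidden_size=1024, llm_hidden_size=2048,
#                         depth=2, hidden_act="gelu", bias=True)
#   proj = ProjectorModel(cfg)
#   out = proj(torch.randn(2, 196, 1024))  # -> (2, 196, 2048)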
class StopWordStoppingCriteria(StoppingCriteria):
"""StopWord stopping criteria."""
def __init__(self, tokenizer, stop_word):
self.tokenizer = tokenizer
self.stop_word = stop_word
self.length = len(self.stop_word)
def __call__(self, input_ids, *args, **kwargs) -> bool:
cur_text = self.tokenizer.decode(input_ids[0])
cur_text = cur_text.replace('\r', '').replace('\n', '')
return cur_text[-self.length:] == self.stop_word
def get_stop_criteria(
tokenizer,
    stop_words=(),
):
stop_criteria = StoppingCriteriaList()
for word in stop_words:
stop_criteria.append(StopWordStoppingCriteria(tokenizer, word))
return stop_criteria
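# Illustrative usage sketch (assumption, not part of the original file): build
# stopping criteria from literal stop words and pass them to `generate`. The
# tokenizer path and stop word are placeholders.
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained(llm_name_or_path)
#   stop_criteria = get_stop_criteria(tokenizer, stop_words=["<|im_end|>"])
#   model.generate(..., stopping_criteria=stop_criteria)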
class VectorLLMWrapModel(PreTrainedModel):
config_class = VectorLLMConfig
main_input_name = 'pixel_values'
base_model_prefix = 'language_model'
_no_split_modules = ['DINOv3ViTModel', 'Qwen3DecoderLayer']
_supports_flash_attn_2 = True
supports_gradient_checkpointing = True
def __init__(
self, config: VectorLLMConfig, vision_model=None, language_model=None,
projector=None, pos_embeds=None, use_flash_attn=True, vectorllm_model=None,
):
super().__init__(config)
        config.vision_config.use_flash_attn = bool(use_flash_attn)
        config.llm_config._attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'
vit_hidden_size = config.vision_hidden_size
llm_hidden_size = config.hidden_size
self.vit_hidden_size = vit_hidden_size
self.llm_hidden_size = llm_hidden_size
self.pixel_idx = config.pixel_idx
self.num_cls_register_tokens = config.num_cls_register_tokens
if vectorllm_model is None:
self.model = VectorLLMModel(
config=config, vision_model=vision_model,
language_model=language_model, projector=projector,
pos_embeds=pos_embeds, use_flash_attn=use_flash_attn
)
else:
self.model = vectorllm_model
@property
def lm_head(self):
return self.model.get_output_embeddings()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def get_output_embeddings(self):
return self.model.get_output_embeddings()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
):
return self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
pixel_values=pixel_values,
labels=labels,
)
@torch.no_grad()
def generate(
self,
pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
visual_features: Optional[torch.FloatTensor] = None,
generation_config: Optional[GenerationConfig] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**generate_kwargs,
) -> torch.LongTensor:
return self.model.generate(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
visual_features=visual_features,
generation_config=generation_config,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**generate_kwargs,
)
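# Illustrative loading sketch (assumption, not part of the original file): the
# wrapper delegates `forward` and `generate` to the inner VectorLLMModel, so it
# can be loaded like any remote-code checkpoint. The repo path is a placeholder
# and exact auto-class mapping depends on the checkpoint's config.
#
#   model = VectorLLMWrapModel.from_pretrained("path/to/vectorllm-hf",
#                                              torch_dtype=torch.bfloat16).eval()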
class VectorLLMModel(PreTrainedModel):
config_class = VectorLLMConfig
main_input_name = 'pixel_values'
base_model_prefix = 'language_model'
_no_split_modules = ['DINOv3ViTModel', 'Qwen3DecoderLayer']
_supports_flash_attn_2 = True
supports_gradient_checkpointing = True
def __init__(
self, config: VectorLLMConfig, vision_model=None, language_model=None,
projector=None, pos_embeds=None, use_flash_attn=True
):
super().__init__(config)
        config.vision_config.use_flash_attn = bool(use_flash_attn)
        config.llm_config._attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'
if vision_model is not None:
self.vision_model = vision_model
else:
self.vision_model = DINOv3ViTModel(config.vision_config)
if language_model is not None:
self.language_model = language_model
else:
self.language_model = Qwen3ForCausalLM(config.llm_config)
vit_hidden_size = config.vision_hidden_size
llm_hidden_size = config.hidden_size
self.vit_hidden_size = vit_hidden_size
self.llm_hidden_size = llm_hidden_size
if projector is not None:
self.projector = projector
else:
self.projector = ProjectorModel(config.projector_config)
w, h = (config.regression_size[0] // 16, config.regression_size[1] // 16)
n_pos = w * h
if pos_embeds is not None:
self.visual_pos_embeddings = pos_embeds
else:
self.visual_pos_embeddings = nn.Embedding(n_pos, self.vit_hidden_size)
self.pixel_idx = config.pixel_idx
self.num_cls_register_tokens = config.num_cls_register_tokens
def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
lora_config = LoraConfig(
r=r,
target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'],
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
)
self.vision_model = get_peft_model(self.vision_model, lora_config)
self.vision_model.print_trainable_parameters()
def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
# Determine the target modules based on the architecture of the language model
target_modules = ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj']
lora_config = LoraConfig(
r=r,
target_modules=target_modules,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
task_type='CAUSAL_LM'
)
self.language_model = get_peft_model(self.language_model, lora_config)
self.language_model.enable_input_require_grads()
self.language_model.print_trainable_parameters()
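    # Illustrative fine-tuning sketch (assumption, not part of the original file):
    # both helpers attach LoRA adapters while keeping the base weights frozen.
    # The hyperparameters below are placeholders.
    #
    #   model.wrap_backbone_lora(r=64, lora_alpha=128)
    #   model.wrap_llm_lora(r=64, lora_alpha=128)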
def extract_feature(self, pixel_values):
features = self.vision_model(pixel_values).last_hidden_state[:, self.num_cls_register_tokens:, :] # (B, N, C)
features.requires_grad_(True)
pos_embed = self.visual_pos_embeddings.weight.unsqueeze(0)
pos_embed = pos_embed.repeat(features.shape[0], 1, 1)
features = features + pos_embed
return features
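    # Shape note (added for clarity): `extract_feature` returns patch features of
    # shape (B, N, vit_hidden_size). The first `num_cls_register_tokens` tokens
    # (CLS/register tokens) are dropped, and a learned positional embedding with
    # n_pos = (regression_size[0] // 16) * (regression_size[1] // 16) entries is
    # added to every sample.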
@property
def lm_head(self):
return self.language_model.get_output_embeddings()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
):
if type(pixel_values) is list or pixel_values.ndim == 5:
if type(pixel_values) is list:
pixel_values = [
x.unsqueeze(0) if x.ndim == 3 else x for x in pixel_values
]
# b*n, c, h, w
concat_images = torch.cat(
[image.to(self.vision_model.dtype) for image in pixel_values], dim=0)
elif pixel_values.ndim == 4:
concat_images = pixel_values.to(self.vision_model.dtype)
else:
raise NotImplementedError()
        # Images whose pixels sum to zero are padding placeholders for text-only samples.
        image_flags = (torch.sum(concat_images, dim=(1, 2, 3)) != 0).long()
        use_cache = use_cache if use_cache is not None else False
outputs = self._llm_forward(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
image_flags=image_flags,
pixel_values=concat_images,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
use_cache=use_cache,
)
return outputs
def _llm_forward(
self,
pixel_values: torch.FloatTensor,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
image_flags: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
return_dict = return_dict if return_dict is not None \
else self.config.use_return_dict
image_flags = image_flags.squeeze(-1)
        # Clone the embedding output so the in-place splice of visual embeddings
        # below does not trigger an autograd in-place modification error.
input_embeds = self.language_model.get_input_embeddings()(
input_ids).clone()
vit_embeds = self.extract_feature(pixel_values)
vit_embeds = vit_embeds.to(input_embeds.dtype)
vit_embeds = vit_embeds[image_flags == 1]
B, N, C = input_embeds.shape
input_embeds = input_embeds.reshape(B * N, C)
vit_embeds = vit_embeds.to(input_embeds.dtype)
input_ids = input_ids.reshape(B * N)
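        # Splice visual embeddings into the token sequence: positions whose input id
        # equals `pixel_idx` (the image placeholder token) are overwritten with the
        # flattened ViT embeddings.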
selected = (input_ids == self.pixel_idx)
try:
input_embeds[selected] = vit_embeds.reshape(-1, C)
except Exception as e:
vit_embeds = vit_embeds.reshape(-1, C)
print(f'warning: {e}, input_embeds[selected].shape='
f'{input_embeds[selected].shape}, '
f'vit_embeds.shape={vit_embeds.shape}')
n_token = selected.sum()
if n_token > len(vit_embeds):
print(f"Wrong !!! {n_token} image tokens in text but only {len(vit_embeds)} vit embeds !!!")
expand_ratio = n_token // len(vit_embeds) + 1
vit_embeds = torch.cat([vit_embeds] * expand_ratio, dim=0)
input_embeds[selected] = vit_embeds[:n_token]
input_embeds = input_embeds.reshape(B, N, C)
outputs = self.language_model(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = outputs.logits
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(
-1, self.language_model.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@torch.no_grad()
def generate(
self,
pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
visual_features: Optional[torch.FloatTensor] = None,
generation_config: Optional[GenerationConfig] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**generate_kwargs,
) -> torch.LongTensor:
device = self.device
if pixel_values is not None:
if visual_features is not None:
vit_embeds = visual_features
else:
if type(pixel_values) is list or pixel_values.ndim == 5:
if type(pixel_values) is list:
pixel_values = [
x.unsqueeze(0) if x.ndim == 3 else x for x in pixel_values
]
# b*n, c, h, w
pixel_values = torch.cat(
[image.to(self.vision_model.dtype) for image in pixel_values], dim=0)
vit_embeds = self.extract_feature(pixel_values.to(device))
image_flags = torch.sum(pixel_values, dim=(1, 2, 3)) != 0
image_flags = image_flags.long()
vit_embeds = vit_embeds[image_flags == 1]
input_embeds = self.language_model.get_input_embeddings()(input_ids.to(device))
vit_embeds = vit_embeds.to(input_embeds.dtype)
B, N, C = input_embeds.shape
input_embeds = input_embeds.reshape(B * N, C)
input_ids = input_ids.reshape(B * N)
selected = (input_ids == self.pixel_idx)
assert selected.sum() != 0
input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
input_embeds = input_embeds.reshape(B, N, C)
else:
input_embeds = self.language_model.get_input_embeddings()(input_ids)
outputs = self.language_model.generate(
inputs_embeds=input_embeds,
attention_mask=attention_mask.to(device),
generation_config=generation_config,
output_hidden_states=output_hidden_states,
# return_dict=return_dict,
# use_cache=True,
# return_dict_in_generate=True,
**generate_kwargs,
)
return outputs
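# Illustrative end-to-end sketch (assumption, not part of the original file):
# combine the processor, stop-word criteria, and `generate`. The checkpoint
# path, prompt, image, and stop word are placeholders, and the processor API
# is assumed to follow the usual transformers processor conventions.
#
#   processor = VectorLLMProcessor.from_pretrained("path/to/vectorllm-hf")
#   inputs = processor(text=prompt, images=image, return_tensors="pt")
#   stop_criteria = get_stop_criteria(processor.tokenizer, stop_words=["<|im_end|>"])
#   output_ids = model.generate(
#       pixel_values=inputs["pixel_values"],
#       input_ids=inputs["input_ids"],
#       attention_mask=inputs["attention_mask"],
#       generation_config=GenerationConfig(max_new_tokens=512),
#       stopping_criteria=stop_criteria,
#   )
#   print(processor.tokenizer.decode(output_ids[0], skip_special_tokens=True))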