| from llava.model.multimodal_encoder.processor import Blip2ImageTrainProcessor |
| from llava.model import LlavaMistralForCausalLM |
| from transformers import AutoTokenizer |
| from llava.constants import MM_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VIDEO_PATCH_TOKEN, DEFAULT_VIDEO_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN |
| from llava.conversation import conv_templates |
| import torch |
| from llava.mm_utils import tokenizer_image_token |
| import numpy as np |
| from PIL import Image |
| import os |
|
|
|
|
# Maximum number of frames passed to the model per video: generate_images
# uniformly subsamples any longer frame sequence down to this count.
NUM_SEGMENTS = 10
|
|
|
|
def load_model(model_path, device_map):
    """Load the LLaVA-Mistral checkpoint with its tokenizer and image processor.

    Args:
        model_path: Checkpoint location handed to both ``from_pretrained`` calls.
        device_map: Device placement forwarded to the language model and, when
            the vision tower still has to be loaded, to the vision tower.

    Returns:
        A ``(model, tokenizer, image_processor)`` tuple. The model has been
        cast to ``torch.float16`` before it is returned.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = LlavaMistralForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        device_map=device_map,
    )

    # Register the image/video boundary markers as special tokens, then grow
    # the embedding table so it matches the enlarged vocabulary.
    boundary_tokens = [
        DEFAULT_IM_START_TOKEN,
        DEFAULT_IM_END_TOKEN,
        DEFAULT_VIDEO_START_TOKEN,
        DEFAULT_VIDEO_END_TOKEN,
    ]
    tokenizer.add_tokens(boundary_tokens, special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))

    # The vision tower may be lazily initialised; load it on demand.
    tower = model.get_vision_tower()
    if not tower.is_loaded:
        tower.load_model(device_map=device_map)

    # Evaluation-mode preprocessing at the resolution stored in the model config.
    image_processor = Blip2ImageTrainProcessor(
        image_size=model.config.img_size,
        is_training=False,
    )

    model.to(torch.float16)
    return model, tokenizer, image_processor
|
|
|
|
def generate_input_ids(tokenizer):
    """Tokenize the fixed "describe this video" prompt.

    A single-turn ``'v1'`` conversation is built whose user message is the
    video placeholder (start token, video token, end token), a newline and
    the instruction text; the assistant turn is left open for generation.

    Args:
        tokenizer: Tokenizer passed through to ``tokenizer_image_token``.

    Returns:
        The prompt token ids as returned by ``tokenizer_image_token`` with
        ``return_tensors='pt'``, with a leading dimension added by
        ``unsqueeze(0)``.
    """
    video_placeholder = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_TOKEN + DEFAULT_VIDEO_END_TOKEN
    question = video_placeholder + '\n' + "Describe the following video in detail."

    conversation = conv_templates['v1'].copy()
    conversation.append_message(conversation.roles[0], question)
    # A ``None`` message leaves the second role's turn empty.
    conversation.append_message(conversation.roles[1], None)

    token_ids = tokenizer_image_token(
        conversation.get_prompt(),
        tokenizer,
        MM_TOKEN_INDEX,
        return_tensors='pt',
    )
    return token_ids.unsqueeze(0)
|
|
|
|
def generate_images(frame_folder, image_processor):
    """Load, subsample and preprocess the frames of one video.

    Args:
        frame_folder: Directory of frame images, read with ``load_frames``.
        image_processor: Object whose ``preprocess(image)`` returns a value
            with a ``.shape`` that ``torch.stack`` accepts.

    Returns:
        The preprocessed frames stacked along a new leading dimension when
        they all share one shape; otherwise the plain list of preprocessed
        frames, unchanged from the original behaviour.

    Raises:
        ValueError: If ``frame_folder`` yields no frames. Without this check
            the failure would surface later as an opaque ``torch.stack``
            error about an empty tensor list.
    """
    frames = load_frames(frame_folder)
    if not frames:
        raise ValueError(f"No frames found in {frame_folder!r}")

    # Cap the number of frames; shorter clips are used as-is.
    if len(frames) > NUM_SEGMENTS:
        frames = uniform_sample(frames, NUM_SEGMENTS)

    processed = [image_processor.preprocess(frame) for frame in frames]

    first_shape = processed[0].shape
    if all(item.shape == first_shape for item in processed):
        return torch.stack(processed, dim=0)
    # Mixed shapes cannot be stacked; hand the list back to the caller.
    return processed
|
|
|
|
def uniform_sample(frames, num_segments):
    """Pick ``num_segments`` evenly spaced items from ``frames``.

    Positions are ``num_segments`` evenly spaced values between the first and
    last index (both included), truncated to integers.

    Args:
        frames: Indexable sequence to sample from.
        num_segments: Number of items to pick.

    Returns:
        A list with the selected items in their original order.
    """
    last_index = len(frames) - 1
    # astype(int) truncates the fractional positions toward zero.
    positions = np.linspace(0, last_index, num_segments).astype(int).tolist()
    return [frames[position] for position in positions]
|
|
def load_frames(frames_dir):
    """Load every frame in ``frames_dir`` as an RGB image, in numeric order.

    Only directory entries whose name ends with ``'jpg'`` are considered, and
    each must have an integer stem (``0.jpg``, ``17.jpg``, ...) because the
    stem is parsed with ``int`` to obtain the sort key.

    Args:
        frames_dir: Directory containing the extracted video frames.

    Returns:
        A list of RGB ``PIL.Image`` objects ordered by frame number.

    Raises:
        ValueError: If a matching file name does not have an integer stem.
    """
    indexed_names = [
        (int(os.path.splitext(name)[0]), name)
        for name in os.listdir(frames_dir)
        if name.endswith('jpg')
    ]
    # Sort on the parsed number so that e.g. 10.jpg comes after 9.jpg.
    indexed_names.sort(key=lambda entry: entry[0])

    frames = []
    for _, name in indexed_names:
        # Image.open keeps the file handle open; convert() produces an
        # independent in-memory image, so the handle can be released right
        # away instead of leaking one descriptor per frame.
        with Image.open(os.path.join(frames_dir, name)) as raw:
            frames.append(raw.convert('RGB'))
    return frames
|
|
|
|
class MASPVisionWrapper(torch.nn.Module):
    """Chains the visual components into a single callable module.

    The forward pass runs: vision tower -> ``ln_vision`` -> ``qformer.bert``
    (with the query tokens attending to the visual features) -> addition of
    the frame position encoding -> projector.
    """

    def __init__(self, vision_tower, qformer, projector, query_tokens, frame_position_encoding, ln_vision):
        """Store the components; nothing is copied or re-initialised."""
        super().__init__()
        self.vision_tower = vision_tower
        self.qformer = qformer
        self.projector = projector
        self.query_tokens = query_tokens
        self.ln_vision = ln_vision
        self.frame_position_encoding = frame_position_encoding

    def forward(self, images):
        """Encode ``images`` and return the projector output.

        The leading dimension of the encoded features is treated as the frame
        index: it sizes the query-token expansion and supplies the ids fed to
        the frame position encoding.
        """
        encoded = self.ln_vision(self.vision_tower(images))

        # Every visual position is attendable: an all-ones mask covering all
        # but the last (feature) dimension.
        full_mask = torch.ones(encoded.shape[:-1], dtype=torch.long, device=encoded.device)
        queries = self.query_tokens.expand(encoded.shape[0], -1, -1)

        # Run the Q-Former in the vision tower's dtype.
        compute_dtype = self.vision_tower.dtype
        qformer_out = self.qformer.bert(
            query_embeds=queries.to(compute_dtype),
            encoder_hidden_states=encoded.to(compute_dtype),
            encoder_attention_mask=full_mask,
            return_dict=True,
        )
        frame_tokens = qformer_out.last_hidden_state.to(compute_dtype)

        # One position id per frame, broadcast over that frame's query tokens.
        frame_ids = torch.arange(frame_tokens.shape[0], dtype=torch.long, device=frame_tokens.device)
        frame_tokens += self.frame_position_encoding(frame_ids).unsqueeze(-2)
        return self.projector(frame_tokens)
|
|
|
|
def inference(model_path, frame_folder):
    """Generate a description of the video in ``frame_folder`` and print it.

    The visual features are computed with ``MASPVisionWrapper`` and spliced
    into the prompt embeddings in place of the first ``MM_TOKEN_INDEX``
    token; the resulting embedding sequence is passed to
    ``model.generate_from_base_class``.

    Args:
        model_path: Checkpoint location forwarded to ``load_model``.
        frame_folder: Directory of frames forwarded to ``generate_images``.

    Returns:
        None. The stripped, decoded text is written to stdout.
    """
    model, tokenizer, image_processor = load_model(model_path, device_map={"": 0})
    device = model.device
    prompt_ids = generate_input_ids(tokenizer)[0].to(device)
    frames = generate_images(frame_folder, image_processor).to(device)

    vision_module = MASPVisionWrapper(
        vision_tower=model.get_vision_tower(),
        qformer=model.get_qformer(),
        projector=model.get_model().mm_projector,
        query_tokens=model.get_query_tokens(),
        frame_position_encoding=model.get_frame_position_encoding(),
        ln_vision=model.get_ln_vision(),
    )

    # temperature=0.01 with do_sample=True keeps sampling enabled but close
    # to deterministic.
    generation_options = dict(
        do_sample=True,
        temperature=0.01,
        top_p=None,
        num_beams=1,
        max_new_tokens=1024,
        pad_token_id=tokenizer.eos_token_id,
        use_cache=True,
    )

    with torch.inference_mode():
        # Merge the first two dimensions of the vision output into one
        # sequence of visual tokens.
        visual_embeds = vision_module(frames).flatten(0, 1)

        # Index of the first multimodal placeholder in the prompt.
        placeholder_pos = torch.where(prompt_ids == MM_TOKEN_INDEX)[0][0]
        embed_tokens = model.get_model().embed_tokens
        text_before = embed_tokens(prompt_ids[:placeholder_pos])
        text_after = embed_tokens(prompt_ids[placeholder_pos + 1:])

        # text-before + visual tokens + text-after, with a leading dimension
        # added by unsqueeze(0).
        inputs_embeds = torch.cat([text_before, visual_embeds, text_after]).unsqueeze(0)

        output_ids = model.generate_from_base_class(
            inputs_embeds=inputs_embeds,
            **generation_options,
        )

    description = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    print(description)
|
|
|
|
if __name__ == '__main__':
    # Script entry point: runs a single captioning pass. Both paths are
    # hard-coded absolute locations; edit them to point at a local checkpoint
    # and frame directory before running.
    model_path = '/mnt/bn/algo-masp-nas-2/xiangchen/model/masp_models/llava-thothv2_mar_release_all_data'
    # Directory of extracted frames; load_frames only picks up files ending in
    # 'jpg' and parses each file stem as an integer frame number.
    frame_folder = '/mnt/bn/yukunfeng-nasdrive/xiangchen/masp_data/20231110_ttp/video/v12044gd0000cl5c6rfog65i2eoqcqig'
    inference(model_path, frame_folder)