| import sys |
| import torch |
| import os |
| import random |
| from io import BytesIO |
| import numpy as np |
| import time |
| from llava.constants import MM_TOKEN_INDEX, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN, DEFAULT_VIDEO_TOKEN, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN |
| from llava.conversation import conv_templates, SeparatorStyle |
| from llava.utils import disable_torch_init |
| from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_images_v2 |
| from llava.model.builder import load_pretrained_model |
| from llava.model.multimodal_encoder.processor import Blip2ImageTrainProcessor |
| from llava.model import LlavaMistralForCausalLM |
| from llava.model.multimodal_encoder.eva_vit import create_eva_vit_g |
| import torch_neuronx |
| import torch |
| import torch_neuronx |
| from llava.model import LlavaMistralForCausalLM |
| from transformers import AutoTokenizer |
| from llava.constants import MM_TOKEN_INDEX, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN, DEFAULT_VIDEO_TOKEN, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN |
|
|
|
|
| from transformers import CLIPImageProcessor |
| from PIL import Image |
| import logging |
| from qformer_tian import BertConfig, BertModel |
|
|
|
|
def select_frames(input_frames, num_segments=10):
    """Pick ``num_segments`` items from ``input_frames`` at evenly spaced positions.

    Positions are spread linearly from index 0 to the last index and then
    converted to ``int`` (truncation), so items repeat when there are fewer
    frames than segments.

    Args:
        input_frames: An indexable sequence of frames.
        num_segments: How many items to return.

    Returns:
        A list of the selected items, in order.
    """
    chosen = []
    for position in np.linspace(0, len(input_frames) - 1, num_segments).astype(int):
        chosen.append(input_frames[position])
    return chosen
|
|
| |
def generate_input_ids(tokenizer, *, conv_mode='v1', question="Describe the following video in detail."):
    """Build and tokenize a single-turn video prompt.

    The user turn is the video start/placeholder/end tokens, a newline, and
    ``question``. The assistant turn is appended with ``None`` (no content).

    Args:
        tokenizer: Tokenizer passed through to ``tokenizer_image_token``.
        conv_mode: Key into ``conv_templates``; defaults to ``'v1'`` (the
            previously hard-coded template).
        question: Text placed after the video tokens; defaults to the
            previously hard-coded question.

    Returns:
        ``(input_ids, conv)``:
        - ``input_ids`` is the tensor returned by
          ``tokenizer_image_token(..., return_tensors='pt')`` with a leading
          dimension of size 1 added by ``unsqueeze(0)``.
        - ``conv`` is the populated conversation copy.
    """
    conv = conv_templates[conv_mode].copy()  # copy so the shared template is not mutated
    user_text = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_TOKEN + DEFAULT_VIDEO_END_TOKEN + '\n' + question
    conv.append_message(conv.roles[0], user_text)
    # Empty assistant slot; presumably leaves the prompt ready for generation.
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    # MM_TOKEN_INDEX is the placeholder id handed to tokenizer_image_token for the video token.
    input_ids = tokenizer_image_token(prompt, tokenizer, MM_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
    return input_ids, conv
|
|
def uniform_sample(frames, num_segments):
    """Return ``num_segments`` items taken at evenly spaced positions of ``frames``.

    Indices come from ``np.linspace(0, len(frames) - 1, num_segments)``
    truncated to ``int``, so items repeat when ``len(frames) < num_segments``.

    Args:
        frames: An indexable, sized sequence (list, tuple, array, ...).
        num_segments: Number of items to return; ``0`` yields ``[]``.

    Returns:
        A list of the selected items, in order.

    Raises:
        ValueError: If ``frames`` is empty while ``num_segments > 0``. Without
            this check the linspace would produce index -1 and fail with an
            opaque ``IndexError``.
    """
    # len() rather than truthiness so NumPy arrays are accepted too.
    if num_segments > 0 and len(frames) == 0:
        raise ValueError("uniform_sample() needs at least one frame to sample from")
    indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
    return [frames[ind] for ind in indices]
|
|
# Output directory for every compiled Neuron module and exported state dict.
save_root = './inf2_weights'
# exist_ok=True makes this idempotent and avoids the isdir()/makedirs() race.
# It still raises if save_root exists but is not a directory.
os.makedirs(save_root, exist_ok=True)

# Artifact paths written by this script.
EVITG_SAVE_PATH = os.path.join(save_root, 'neuron_eva_vit_batch7.pth')  # traced vision encoder
LAYERNORM_SAVE_PATH = os.path.join(save_root, 'ln_state_dict.pth')  # vision layer-norm state dict
QUERYTOKEN_SAVE_PATH = os.path.join(save_root, 'query_tokens.pth')  # {'query_tokens': tensor}
BERT_SAVE_PATH = os.path.join(save_root, 'neuron_bert.pth')  # traced Q-Former BERT
POSITION_ENCODING_SAVE_PATH = os.path.join(save_root, 'frame_position_encoding.pth')  # state dict
PROJECTOR_SAVE_PATH = os.path.join(save_root, 'projector.pth')  # mm_projector state dict
EMBED_TOKENS_SAVE_PATH = os.path.join(save_root, 'embed_tokens.pth')  # embed_tokens state dict
|
|
|
|
# Load the tokenizer and the full LLaVA-Mistral checkpoint on the CPU in fp32.
model_path = './llava-mistral_videollava_ptv12_250k_samep_only_sopv2_mistralv2_scratch/'
disable_torch_init()

device_map = {"": 'cpu'}  # the "" key places the whole model on one device
kwargs = dict(device_map=device_map, torch_dtype=torch.float32)

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)

# Register the image/video delimiter tokens as special tokens, then resize the
# embedding matrix so every newly added token id has a row.
tokenizer.add_tokens(
    [
        DEFAULT_IM_START_TOKEN,
        DEFAULT_IM_END_TOKEN,
        DEFAULT_VIDEO_START_TOKEN,
        DEFAULT_VIDEO_END_TOKEN,
    ],
    special_tokens=True,
)
model.resize_token_embeddings(len(tokenizer))
|
|
# Rebuild the vision tower in fp32 and trace it for Neuron.
# Assignment (not the former `==` comparison, which was a no-op) so the
# requested precision is actually recorded in the config.
# NOTE(review): presumably consulted by load_model(); confirm in the vision tower class.
model.config.vit_precision = 'fp32'
vision_tower = model.get_vision_tower()
# Clearing the flag presumably forces load_model() to (re)build the encoder.
vision_tower.is_loaded = False
vision_tower.load_model(device_map=device_map)
vision_tower = vision_tower.to(torch.float32).eval()
print('vision tower hiidden size')
print(vision_tower.hidden_size)

# torch_neuronx.trace compiles for the example input's shape, so the saved
# module expects fp32 batches of shape (7, 3, 224, 224).
# NOTE(review): confirm batch_size=7 and img_size=224 match the inference
# pipeline and model.config.img_size.
batch_size = 7
img_size = 224
input_shape = (batch_size, 3, img_size, img_size)
input_data = torch.zeros(input_shape, dtype=torch.float32)
model_neuronx = torch_neuronx.trace(vision_tower, input_data, compiler_args=["--model-type=transformer"])
model_neuronx.save(EVITG_SAVE_PATH)
|
|
# Preprocessor for raw video frames, sized from the checkpoint config.
image_processor = Blip2ImageTrainProcessor(
    image_size=model.config.img_size,
    is_training=False)

input_ids, conv = generate_input_ids(tokenizer)
device = torch.device('cpu')
model = model.to(device)
conv_mode = 'v1'
NUM_SEGMENTS = 10

# Sample video: a directory of frame images whose file stems are integer frame
# numbers. Entries whose stem is not a decimal number (hidden files, notes, ...)
# are skipped instead of crashing int().
video_dir = './v12044gd0000cl5c6rfog65i2eoqcqig'
frames = [
    item for item in os.listdir(video_dir)
    if os.path.splitext(item)[0].isdecimal()
]
# Stable sort by numeric frame number, then turn names into full paths.
frames = [
    os.path.join(video_dir, item)
    for item in sorted(frames, key=lambda name: int(os.path.splitext(name)[0]))
]

# Sample the paths first so only the NUM_SEGMENTS selected files are opened and
# decoded. uniform_sample depends only on the sequence length, so the chosen
# frames are the same as when sampling already-decoded images.
images = []
for frame_path in uniform_sample(frames, NUM_SEGMENTS):
    # Close each file deterministically. convert() returns a separate image
    # object that remains usable after the file is closed.
    with Image.open(frame_path) as frame_image:
        images.append(frame_image.convert('RGB'))
images = process_images_v2(images, image_processor, model.config)
# NOTE(review): `images`, `input_ids` and `conv` are not used later in the visible script.
|
|
| |
# Export the vision layer-norm module (model.get_ln_vision()) as a plain state dict.
ln_vision = model.get_ln_vision().eval()
ln_state_dict = ln_vision.state_dict()
torch.save(ln_state_dict, LAYERNORM_SAVE_PATH)
|
|
|
|
# Save the tensor behind model.get_query_tokens() under the key 'query_tokens'.
# `.data` stores the underlying tensor rather than the parameter object.
query_tokens = model.get_query_tokens()
query_tokens_state_dict = dict(query_tokens=query_tokens.data)
torch.save(query_tokens_state_dict, QUERYTOKEN_SAVE_PATH)
|
|
| |
# --- Q-Former BERT: rebuild a trimmed BertModel, load weights, trace for Neuron ---
# Source weights: the checkpoint's Q-Former BERT trunk in eval mode and fp32.
qformer = model.get_qformer()
bert_torch = qformer.bert.eval().to(torch.float32)

# Fresh BertModel configured for cross-attention over visual features.
# NOTE(review): these values are hard-coded. Confirm that vision_width equals
# vision_tower.hidden_size and num_query_token equals query_tokens.shape[1].
vision_width, cross_attention_freq, num_query_token = 1408, 2, 32
encoder_config = BertConfig.from_pretrained("bert-base-uncased")
encoder_config.encoder_width = vision_width
encoder_config.add_cross_attention = True
encoder_config.cross_attention_freq = cross_attention_freq
encoder_config.query_length = num_query_token
bert = BertModel(encoder_config, add_pooling_layer=False)

# Remove these submodules before loading. load_state_dict() is strict by
# default, so it only succeeds if the module tree matches bert_torch's keys
# (presumably the checkpoint's Q-Former had the same parts removed).
bert.embeddings.word_embeddings = bert.embeddings.position_embeddings = None
for layer in bert.encoder.layer:
    layer.output = layer.intermediate = None

bert.load_state_dict(bert_torch.state_dict())
bert = bert.eval()

# torch_neuronx.trace compiles for these exact shapes, so inference must use
# a batch of 70 with matching dimensions.
# NOTE(review): argument meanings below are inferred from the sizes
# (768 = bert-base hidden size, 32 = num_query_token, 1408 = vision_width).
# Confirm against qformer_tian.BertModel.forward.
input_example = (
    torch.zeros(70, 32, 768, dtype=torch.float32),    # presumably query embeddings
    torch.zeros(70, 256, 1408, dtype=torch.float32),  # presumably visual features
    torch.zeros(70, 256, dtype=torch.int64),          # presumably visual attention mask
)
neuron_bert = torch_neuronx.trace(bert, input_example)
neuron_bert.save(BERT_SAVE_PATH)
|
|
| |
# Export the frame position encoding and the multimodal projector as fp32
# state dicts; neither module is traced for Neuron here.
frame_position_encoding = model.get_frame_position_encoding().eval().to(torch.float32)
projector = model.get_model().mm_projector.eval().to(torch.float32)

torch.save(frame_position_encoding.state_dict(), POSITION_ENCODING_SAVE_PATH)
torch.save(projector.state_dict(), PROJECTOR_SAVE_PATH)
|
|
| |
# Export the token embedding module (get_model().embed_tokens) as an fp32 state dict.
embed_tokens = model.get_model().embed_tokens.eval().to(torch.float32)
torch.save(embed_tokens.state_dict(), EMBED_TOKENS_SAVE_PATH)