| import cv2 |
| import numpy as np |
| from PIL import Image |
| from concurrent.futures import ThreadPoolExecutor |
| from config.configu import * |
| from models.model import * |
| from models.similarity import * |
| from sklearn.cluster import KMeans |
| from utils.utils import * |
| import warnings |
| from typing import Any, List, Optional, Tuple, Union |
| import torch |
| import random |
| import torch.utils.checkpoint |
| import transformers |
| from torch import nn |
| from torch.nn import CrossEntropyLoss |
| from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM, |
| LlamaTokenizer) |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.utils import ModelOutput, logging |
|
|
| from .configuration_internvl_chat import InternVLChatConfig |
| from .conversation import get_conv_template |
| from .modeling_intern_vit import InternVisionModel |
| from .modeling_internlm2 import InternLM2ForCausalLM |
|
|
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
def coord_transform(box, return_4=True):
    """Convert between the two box layouts used in this module.

    return_4=True:  [[x1, y1], [x2, y2]] -> [x1, y1, x2, y2]
    return_4=False: [x1, y1, x2, y2]     -> [[x1, y1], [x2, y2]]
    """
    if not return_4:
        x1, y1, x2, y2 = box
        return [[x1, y1], [x2, y2]]
    (x1, y1), (x2, y2) = box
    return [x1, y1, x2, y2]
def insert_zeros(input_ids, attention_mask, num_zeros=5):
    """Insert `num_zeros` zero tokens at random positions of a sequence.

    Each inserted position gets token id 0 in `input_ids` and a 1 in
    `attention_mask` (the inserted tokens are attended to). Both tensors
    are cloned; the inputs are not modified, and the results are returned
    on the original device.

    Fix over the original: the inserted columns are built with the batch
    size and dtype of the inputs (the original hard-coded a 1x1 int64
    tensor, which crashed for batch > 1 and could silently change dtype).

    Args:
        input_ids: (batch, seq_len) token-id tensor.
        attention_mask: (batch, seq_len) mask tensor.
        num_zeros: number of zero tokens to insert.

    Returns:
        (input_ids, attention_mask) with seq_len + num_zeros columns.
    """
    device = input_ids.device
    input_ids = input_ids.cpu().clone()
    attention_mask = attention_mask.cpu().clone()

    batch = input_ids.size(0)
    # One column per insertion, matching each tensor's dtype.
    zero_col = torch.zeros((batch, 1), dtype=input_ids.dtype)
    one_col = torch.ones((batch, 1), dtype=attention_mask.dtype)

    for _ in range(num_zeros):
        # randint is inclusive on both ends, so insertion at the very end is possible.
        insert_pos = random.randint(0, input_ids.size(1))
        input_ids = torch.cat((input_ids[:, :insert_pos], zero_col, input_ids[:, insert_pos:]), dim=1)
        attention_mask = torch.cat((attention_mask[:, :insert_pos], one_col, attention_mask[:, insert_pos:]), dim=1)

    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    return input_ids, attention_mask
|
|
|
|
def add_Gaussian_noise(input_embeds, rate=1e-1):
    """Perturb embeddings with Gaussian noise scaled to their statistics.

    The noise is drawn as randn * std + mean of the input, scaled by
    `rate`, and the result is returned on the original device, cast to
    bfloat16. The input tensor itself is not modified.
    """
    device = input_embeds.device
    embeds = input_embeds.cpu().clone()

    mu = embeds.mean()
    sigma = embeds.std()
    noise = torch.randn(embeds.size()) * sigma + mu
    noisy = embeds + rate * noise

    # Move back to the caller's device and match the model's bfloat16 dtype.
    noisy = noisy.to(device)
    return noisy.to(torch.bfloat16)
|
|
|
|
def version_cmp(v1, v2, op='eq'):
    """Compare two version strings with the named `operator` function.

    `op` is any comparison name from the operator module ('eq', 'ge',
    'lt', ...). Versions are parsed with packaging.version so that e.g.
    '4.10' > '4.9'.
    """
    import operator

    from packaging import version
    parsed1 = version.parse(v1)
    parsed2 = version.parse(v2)
    return getattr(operator, op)(parsed1, parsed2)
|
|
def most_frequent_rgb(image_array):
    """Find the most frequent RGB value in an image (used to fill masks).

    Returns:
        (pixel, frequency): the most common color vector and its count.
    """
    channels = image_array.shape[-1]
    flat = image_array.reshape(-1, channels)

    colors, counts = np.unique(flat, axis=0, return_counts=True)
    winner = int(np.argmax(counts))

    return colors[winner], counts[winner]
|
|
def most_frequent_rgb_fast(image_array):
    """Quickly find the most frequent RGB value in an image (no frequency).

    Packs each pixel into a single 24-bit integer and uses np.bincount,
    which is much faster than np.unique on large images.

    Fix over the original: the pixels are cast to int64 before packing.
    Image arrays are typically uint8, and `uint8 * 256**2` either wraps or
    raises under NumPy 2.x's NEP 50 promotion rules, corrupting the result.

    Returns:
        (r, g, b) tuple of ints for the most frequent color.
    """
    # int64 keeps the 24-bit packed value exact regardless of input dtype.
    flattened = image_array.reshape(-1, 3).astype(np.int64)
    rgb_ints = flattened[:, 0] * 256**2 + flattened[:, 1] * 256 + flattened[:, 2]

    counts = np.bincount(rgb_ints)
    most_frequent_index = int(np.argmax(counts))

    # Unpack the winning 24-bit value back into channels.
    r = (most_frequent_index // 256**2) % 256
    g = (most_frequent_index // 256) % 256
    b = most_frequent_index % 256

    return (r, g, b)
|
|
|
|
|
|
def mask_area(image_array, coords, color):
    """Fill every given box of an image with a solid color, in place.

    Args:
        image_array: HxWxC numpy image; modified in place.
        coords: iterable of [x1, y1, x2, y2] boxes.
        color: fill value broadcastable over the channel axis.

    Returns:
        The same (mutated) image array, for chaining.
    """
    for x1, y1, x2, y2 in coords:
        image_array[y1:y2, x1:x2] = color
    return image_array
|
|
|
|
class InternVLChatModel(PreTrainedModel):
    """InternVL chat model (InternViT vision tower + LLM) extended with a
    calligraphy-character alignment pipeline (detector + orderformer +
    perceiver resampler + vector quantization) used by `calli_align` /
    `chat_ocr`.
    """
    config_class = InternVLChatConfig
    main_input_name = 'pixel_values'
    _supports_flash_attn_2 = True
    # Module classes that must not be split across devices by accelerate.
    _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'InternLM2DecoderLayer']
|
|
    def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None):
        """Build the chat model and load the calligraphy-alignment extras.

        Args:
            config: InternVLChatConfig bundling the vision and LLM sub-configs.
            vision_model: optional pre-built vision tower; built from
                config.vision_config when None.
            language_model: optional pre-built LLM; built from
                config.llm_config when None.
        """
        super().__init__(config)

        # The code below relies on transformers >= 4.36.2 APIs.
        assert version_cmp(transformers.__version__, '4.36.2', 'ge')
        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        self.template = config.template
        # Visual tokens fed to the LLM per image tile, after pixel-shuffle
        # downsampling (see extract_feature / pixel_shuffle).
        self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        # Per-codebook-entry (mu, sigma) used to de-normalize quantized
        # embeddings in calli_align; paths come from config.configu.
        self.mu_sigma=torch.load(NORM_PARAMS_PATH)['weight']
        self.mu=self.mu_sigma[:,0].reshape((-1,1))
        self.sigma=self.mu_sigma[:,1].reshape((-1,1))
        # NOTE: this reassigns self.mu_sigma with the same checkpoint weight
        # loaded again inside load_normed_tok_embeddings.
        self.normed_emb,self.mu_sigma=self.load_normed_tok_embeddings(load_checkboard=True)
        # Perceiver resampler compresses per-character ViT features.
        self.resampler=load_perceiver_resampler_2(PERCEIVER_CHECKPOINT,num_layers=4)
        # Orderformer predicts reading order of text columns.
        self.sorter=load_orderformer(ORDERFORMER_CHECKPOINT)

        logger.info(f'num_image_token: {self.num_image_token}')
        logger.info(f'ps_version: {self.ps_version}')

        if vision_model is not None:
            self.vision_model = vision_model
        else:
            self.vision_model = InternVisionModel(config.vision_config)
        if language_model is not None:
            self.language_model = language_model
        else:
            if config.llm_config.architectures[0] == 'LlamaForCausalLM':
                self.language_model = LlamaForCausalLM(config.llm_config)
            elif config.llm_config.architectures[0] == 'InternLM2ForCausalLM':
                self.language_model = InternLM2ForCausalLM(config.llm_config)
            else:
                raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.llm_config.hidden_size

        # Projector from (pixel-shuffled) ViT features into the LLM space.
        self.mlp1 = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size)
        )

        # Set lazily by the chat helpers from the tokenizer.
        self.img_context_token_id = None
        self.conv_template = get_conv_template(self.template)
        self.system_message = self.conv_template.system_message
| def load_normed_tok_embeddings(self,vocab_size=92553, llm_hidden_size=4096,load_checkboard=False): |
| tok_embeddings = nn.Embedding(vocab_size, llm_hidden_size, padding_idx=2).to_empty(device=torch.device('cuda')).to(torch.bfloat16) |
| tok_embeddings.load_state_dict(torch.load(NORM_TOK_EMBEDDING_PATH, weights_only=True, map_location="cpu")) |
| if load_checkboard: |
| checkboard_norm=torch.load(NORM_PARAMS_PATH) |
| |
| return tok_embeddings,checkboard_norm['weight'] |
| return tok_embeddings |
| |
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        image_flags: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Training/inference forward pass.

        Encodes `pixel_values` with the vision tower, splices the visual
        embeddings into the text embeddings at every position holding
        `self.img_context_token_id`, then runs the language model.
        Computes a shifted cross-entropy loss when `labels` is given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        image_flags = image_flags.squeeze(-1)
        input_embeds = self.language_model.get_input_embeddings()(input_ids)

        vit_embeds = self.extract_feature(pixel_values)
        # Keep only the tiles whose flag is 1 (real images, not padding).
        vit_embeds = vit_embeds[image_flags == 1]
        vit_batch_size = pixel_values.shape[0]

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)

        # Assumes torch.distributed is initialized during training — logs only on rank 0.
        if torch.distributed.get_rank() == 0:
            print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')

        input_ids = input_ids.reshape(B * N)
        # Positions where image-context placeholder tokens sit.
        selected = (input_ids == self.img_context_token_id)
        try:
            # Replace placeholder embeddings with the visual embeddings
            # (the * 0.0 keeps the graph connected to the text embeddings).
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
        except Exception as e:
            # Shape mismatch (e.g. truncated sequences): fill as many
            # placeholder slots as we have visual tokens and continue.
            vit_embeds = vit_embeds.reshape(-1, C)
            print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
                  f'vit_embeds.shape={vit_embeds.shape}')
            n_token = selected.sum()
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]

        input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism (labels may live on another device).
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
| def pixel_shuffle(self, x, scale_factor=0.5): |
| n, w, h, c = x.size() |
| |
| x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) |
| |
| x = x.permute(0, 2, 1, 3).contiguous() |
| |
| x = x.view(n, int(h * scale_factor), int(w * scale_factor), |
| int(c / (scale_factor * scale_factor))) |
| if self.ps_version == 'v1': |
| warnings.warn("In ps_version 'v1', the height and width have not been swapped back, " |
| 'which results in a transposed image.') |
| else: |
| x = x.permute(0, 2, 1, 3).contiguous() |
| return x |
|
|
| def extract_feature(self, pixel_values): |
| if self.select_layer == -1: |
| vit_embeds = self.vision_model( |
| pixel_values=pixel_values, |
| output_hidden_states=False, |
| return_dict=True).last_hidden_state |
| else: |
|
|
| vit_embeds = self.vision_model( |
| pixel_values=pixel_values, |
| output_hidden_states=True, |
| return_dict=True).hidden_states[self.select_layer] |
| vit_embeds = vit_embeds[:, 1:, :] |
|
|
| h = w = int(vit_embeds.shape[1] ** 0.5) |
| vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) |
| vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) |
| vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) |
|
|
| vit_embeds = self.mlp1(vit_embeds) |
| return vit_embeds |
| |
    @torch.no_grad()
    def calli_align(self,img_path,detect_model, drop_zero = False, use_hard_vector_quant=False,save_path=None,verbose=False):
        """Detect, order and encode the characters of a calligraphy image.

        Pipeline: `detect_model` finds character boxes (iteratively, masking
        found ones); the orderformer (`self.sorter`) arranges them into
        reading-order columns; each character crop is encoded by the ViT,
        compressed by the perceiver resampler, and vector-quantized against
        the normalized token-embedding table; finally the embeddings are
        de-normalized with per-index (mu, sigma).

        Args:
            img_path: image path or PIL image (None returns (None, None)).
            detect_model: YOLO-style detector returning results with .boxes.
            drop_zero: drop rows quantized to codebook index 0.
            use_hard_vector_quant: snap low-similarity rows to the codebook.
            save_path: when set, writes a debug visualization of the ordered
                boxes next to this prefix.
            verbose: print per-stage timings.

        Returns:
            (back_to_origin_flat, indices): de-normalized embedding rows and
            their quantization indices.
        """
        def dynamic_read(img_path,mode='c'):
            # Read an image from a path or a PIL image.
            # mode 'c' returns a numpy (cv2-style) array, mode 'i' a PIL RGB image.
            if isinstance(img_path, str):
                img = cv2.imread(img_path)
                # cv2 returns None for unreadable files; retry with PIL.
                if img is None:
                    try:
                        img = Image.open(img_path).convert("RGB")
                        img = np.array(img)
                    except:
                        raise ValueError(f"Image at path {img_path} could not be loaded.")
            elif isinstance(img_path, Image.Image):
                img = np.array(img_path)
                img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            else:
                raise TypeError(f"Unsupported image type: {type(img_path)}")
            if mode=='i':
                img=Image.fromarray(img).convert("RGB")
            return img
        import time
        def iterative_only_boxes(model,jpg_path):
            # Run the detector repeatedly, masking detected characters with the
            # background color, until one pass yields <= 250 boxes (the detector
            # caps detections per pass, so dense pages need several rounds).
            image = dynamic_read(jpg_path)
            image_array = np.array(image)
            h, w, channels = image.shape
            boxes=[]
            color=most_frequent_rgb_fast(image_array)
            while True:
                res=model(image_array,verbose=False)[0]
                to_be_masked=[]
                for box in res.boxes:
                    xyxy = box.xyxy.squeeze().tolist()
                    x1, y1, x2, y2 = int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])
                    to_be_masked.append([x1,y1,x2,y2])
                boxes.extend(to_be_masked)
                if len(to_be_masked)>250:
                    image_array=mask_area(image_array,to_be_masked,color)
                else:
                    break
            # Clamp boxes to the image bounds and convert to [[x1,y1],[x2,y2]].
            boxes=[[[max(item[0],0),max(item[1],0)],[min(item[2],w),min(item[3],h)]]for item in boxes]

            # Drop near-duplicate detections (IoU > 0.8). The list is mutated
            # in place, so both scan indices are adjusted on each removal.
            i=0
            length=len(boxes)
            while i<length:
                j=0
                main_box=boxes[i]
                while j<length:
                    if i==j:
                        j+=1
                        continue
                    iou=calculate_iou(coord_transform(main_box),coord_transform(boxes[j]))
                    if iou>0.8:
                        rm = boxes[j]
                        boxes.remove(rm)
                        if j<i:
                            i-=1
                        length-=1
                        j-=1
                    j+=1
                i+=1

            return boxes
        def char2col_with_kmeans(jpg_path,boxes, verbose=False):
            # Group character boxes into text columns. When box areas vary a
            # lot, first split main-text boxes from small annotation boxes
            # with a 2-cluster KMeans on normalized box area.
            def kmeans_boxes(bounding_boxes):
                areas = [ (box[1][0] - box[0][0])*(box[1][1] - box[0][1]) for box in bounding_boxes]
                areas = np.array(areas).reshape(-1, 1)
                kmeans = KMeans(n_clusters=2, random_state=0).fit(areas)
                labels = kmeans.labels_
                group_0 = []
                group_1 = []
                for i, label in enumerate(labels):
                    if label == 0:
                        group_0.append(bounding_boxes[i])
                    else:
                        group_1.append(bounding_boxes[i])
                # Sort each cluster by box width, widest first.
                group_0 = sorted(group_0, key = lambda x: (x[1][0]-x[0][0]), reverse=True)
                group_1 = sorted(group_1, key = lambda x: (x[1][0]-x[0][0]), reverse=True)
                # The cluster whose widest box is wider is the "large" group.
                # Move boxes that look large (wide/tall enough, or with area
                # ratios typical of main text) into it, then use merge_boxes
                # to reclaim column-forming boxes from the small group.
                if (group_1[0][1][0] - group_1[0][0][0]) > (group_0[0][1][0] - group_0[0][0][0]):
                    g1_hs = np.array([x[1][1]-x[0][1] for x in group_1]).mean()
                    thr1 = 1*( group_1[-1][1][0] - group_1[-1][0][0])
                    thr2 = 0.8*g1_hs
                    new_0 = []
                    for ele in group_0:
                        if (ele[1][0] - ele[0][0]) >= thr1 or (ele[1][1] - ele[0][1]) >= thr2 or (areas.min()/(ele[1][0] - ele[0][0])*(ele[1][1] - ele[0][1]) <= 1/5 and areas.mean() / ((ele[1][0] - ele[0][0])*(ele[1][1] - ele[0][1])) <= 1.3):
                            group_1.append(ele)
                        else:
                            new_0.append(ele)
                    grouped_luokuan = merge_boxes(new_0.copy())
                    final_ = []
                    for ele in new_0:
                        if ele in grouped_luokuan:
                            group_1.append(ele)
                        else:
                            final_.append(ele)
                    group_0 = final_
                elif (group_0[0][1][0] - group_0[0][0][0]) > (group_1[0][1][0] - group_1[0][0][0]):
                    # Mirror of the branch above with the cluster roles swapped.
                    g0_hs = np.array([x[1][1]-x[0][1] for x in group_0]).mean()
                    thr1 = 1*( group_0[-1][1][0] - group_0[-1][0][0])
                    thr2 = 0.8*g0_hs
                    new_1 = []
                    for ele in group_1:
                        if (ele[1][0] - ele[0][0]) >= thr1 or (ele[1][1] - ele[0][1]) >= thr2 or (areas.min()/(ele[1][0] - ele[0][0])*(ele[1][1] - ele[0][1]) <= 1/5 and areas.mean() / ((ele[1][0] - ele[0][0])*(ele[1][1] - ele[0][1])) <=1.3):
                            group_0.append(ele)
                        else:
                            new_1.append(ele)
                    grouped_luokuan = merge_boxes(new_1.copy())
                    final_ = []
                    for ele in new_1:
                        if ele in grouped_luokuan:
                            group_0.append(ele)
                        else:
                            final_.append(ele)
                    group_1 = final_
                return group_0,group_1

            def toint(lst):
                # Round a box of either layout to ints.
                if len(lst)==2:
                    return [[int(lst[0][0]),int(lst[0][1])],[int(lst[1][0]),int(lst[1][1])]]
                else:
                    return [int(lst[0]),int(lst[1]),int(lst[2]),int(lst[3])]
            img = dynamic_read(jpg_path)
            h, w, channels = img.shape

            normalized_boxes=[[[item[0][0]/w,item[0][1]/h],[item[1][0]/w,item[1][1]/h]] for item in boxes]
            S=np.array([(item[0][0]-item[1][0])*(item[0][1]-item[1][1]) for item in normalized_boxes])
            # High coefficient of variation plus a tiny minimum area suggests
            # two size classes (main text vs small annotations).
            coef_var=np.std(S)/np.mean(S)
            boxes2class=None
            col2class=None
            if coef_var>0.66 and S.min()/S.mean() <= 1/8:
                boxes1,boxes2=kmeans_boxes(normalized_boxes)
                # Back to pixel coordinates before merging into columns.
                boxes1=[[[item[0][0]*w,item[0][1]*h],[item[1][0]*w,item[1][1]*h]] for item in boxes1]
                boxes2=[[[item[0][0]*w,item[0][1]*h],[item[1][0]*w,item[1][1]*h]] for item in boxes2]
                columns1=merge_boxes(boxes1.copy())
                columns2=merge_boxes(boxes2.copy())
                columns=columns1+columns2
                boxes2class={1:[toint(item) for item in boxes1],2:[toint(item) for item in boxes2]}
                col2class={1:[toint(item) for item in columns1],2:[toint(item) for item in columns2]}
            else:
                columns=merge_boxes(boxes.copy())

            # labelme-style payload consumed by the orderformer's predict().
            results={"imageHeight":h,"imageWidth":w,"shapes":[{"points":toint(col)} for col in columns],
                "boxes2class":boxes2class,"col2class":col2class}

            return results

        def sort_boxes(jpg,detector,model,thres=0.8):
            # Detect character boxes, build columns, let the orderformer order
            # the columns, then emit character boxes column by column, sorted
            # top-to-bottom by vertical center within each column.
            boxes=iterative_only_boxes(detector,jpg)
            data=char2col_with_kmeans(jpg,boxes,verbose=False)
            res=model.predict(data,jpg)
            final_results=[]
            for idx,col in res.items():
                lst=[]
                for item in boxes:
                    ratio=calculate_iou(col,[item[0][0],item[0][1],item[1][0],item[1][1]],mini=True)
                    # A box belongs to a column when most of its area lies inside.
                    if ratio>=thres:
                        lst.append([item[0][0],item[0][1],item[1][0],item[1][1]])
                lst=sorted(lst, key=lambda item: (item[1]+item[3])/2)
                final_results.extend(lst)
            return final_results
        if img_path is None:
            return None,None

        st=time.time()
        boxes=sort_boxes(img_path,detect_model,self.sorter)
        ed=time.time()
        if verbose:
            print(f"YOLO+Orderformer {ed-st:.2f}s")
        if save_path!=None:
            # Debug visualization: draw the ordered boxes with their ranks.
            frame = dynamic_read(img_path)
            name=img_path.split("/")[-1]
            for i,box in enumerate(boxes):
                xyxy = box
                x1, y1, x2, y2 = int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])
                colo = (255,0,0)
                cv2.rectangle(frame, (x1, y1), (x2, y2), thickness=2,color=colo,lineType=cv2.LINE_AA)
                cv2.putText(frame, str(i+1), ((x1+x2)//2, (y1+y2)//2), cv2.FONT_HERSHEY_SIMPLEX, 1.5, colo, thickness=2, lineType=cv2.LINE_AA)
            print(save_path+"oredered_result_"+name)
            cv2.imwrite(save_path+"oredered_result_"+name,frame)

        st=time.time()
        pixel_values=[]
        img=np.array(dynamic_read(img_path,mode='i').convert("RGB"))
        # Crop each character in reading order and preprocess it for the ViT.
        for xyxy in boxes:
            x1, y1, x2, y2 = int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])
            sub_img=Image.fromarray(img[y1:y2,x1:x2])
            pixel_values.append(load_image_2(sub_img).to(torch.bfloat16).cuda())
        ed1=time.time()
        results=torch.cat(pixel_values)

        image_embeddings=self.extract_feature(results)
        ed2=time.time()
        output=self.resampler(image_embeddings)
        ed3=time.time()
        # Vector-quantize the resampled embeddings against the normalized
        # token-embedding table (cosine-similarity nearest neighbor).
        outs=vq_cos_sim(self.normed_emb,output, use_hard_vector_quant)
        ed4=time.time()
        if verbose:
            print(f"Get pixel values {ed1-st:.2f}s")
            print(f"extract feat {ed2-ed1:.2f}s")
            print(f"Resampler forward {ed3-ed2:.2f}")
            print(f"vq cos sim {ed4-ed3:.2f}s")
        if use_hard_vector_quant:
            indices, cos_sim_values = outs
            # Only snap to the codebook when cosine similarity is this low.
            thresh = 0.5
        else:
            indices = outs

        if use_hard_vector_quant:
            print("Dynamic vector quantization...")
            # Blend: rows below the threshold take the codebook entry,
            # rows above keep the resampler output.
            below_mask = (cos_sim_values <= thresh).to(torch.bfloat16).unsqueeze(-1)
            output = output * (1-below_mask) + self.normed_emb.weight[indices] * below_mask

        flattened_output = output.view(-1, output.shape[-1])
        flattened_indices = indices.view(-1)

        if drop_zero:
            # Discard rows quantized to index 0 before de-normalizing.
            filtered_indices=flattened_indices[flattened_indices!=0]
            filtered_output=flattened_output[flattened_indices!=0]

            sigma_flat = self.sigma[filtered_indices]
            mu_flat = self.mu[filtered_indices]

            # De-normalize: x * sigma + mu per quantization index.
            sigma_flat = sigma_flat.expand(-1, filtered_output.shape[-1])
            mu_flat = mu_flat.expand(-1, filtered_output.shape[-1])
            back_to_origin_flat = filtered_output * sigma_flat + mu_flat

        else:
            sigma_flat = self.sigma[flattened_indices]
            mu_flat = self.mu[flattened_indices]
            sigma_flat = sigma_flat.expand(-1, flattened_output.shape[-1])
            mu_flat = mu_flat.expand(-1, flattened_output.shape[-1])
            back_to_origin_flat = flattened_output * sigma_flat + mu_flat

        return back_to_origin_flat, indices
|
|
| def find_coordinates(self,text): |
| import re |
|
|
| numbers = re.findall(r'\d+', text) |
|
|
| numbers = [int(num) for num in numbers] |
| return numbers |
    def chat_ocr(self, tokenizer, detect_model,img_path, questions, generation_config, num_patches_list=None,
                 history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
                 IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', ALIGNED_TOKEN="[UNUSED_TOKEN_140]",verbose=False, image_counts=None,batch=False,
                 use_p=True, drop_zero=False, hard_vq=False, repetition_penalty=1.5,region_wise=False):
        """Single-turn OCR chat with optional calligraphy alignment.

        When use_p is True, runs calli_align and appends one ALIGNED_TOKEN
        placeholder per quantized embedding row; the resulting embeddings
        are passed to generate_ocr as reference_embeds. When region_wise is
        True, the integers in `questions` are read as a crop region
        (x1, x2, y1, y2) and the question is replaced by a fixed OCR prompt.

        Returns the decoded response, or (response, history) when
        return_history is True; returns "检测失败" if region-wise alignment fails.
        """
        pixel_values = None
        if img_path is not None:
            try:
                if region_wise:
                    # Crop to the region encoded in the question, then ask
                    # for all text in that crop.
                    img=np.array(Image.open(img_path).convert("RGB"))
                    coord=self.find_coordinates(questions)
                    x1,x2,y1,y2=coord
                    sub_img=Image.fromarray(img[y1:y2,x1:x2])

                    questions="输出图片中所有文字:"
                    pixel_values=load_image(sub_img).to(torch.bfloat16).to(torch.device("cuda"))
                else:
                    pixel_values=load_image(img_path).to(torch.bfloat16).to(torch.device("cuda"))
            except:
                # NOTE(review): bare except maps every load failure to
                # FileNotFoundError, hiding the original cause.
                raise FileNotFoundError
        if use_p:
            import time
            st=time.time()
            if region_wise:
                try:
                    out_tokens, indices =self.calli_align(sub_img,detect_model, drop_zero = drop_zero, use_hard_vector_quant=hard_vq,verbose=verbose)
                except:
                    return "检测失败"
            else:
                out_tokens, indices =self.calli_align(img_path,detect_model, drop_zero = drop_zero, use_hard_vector_quant=hard_vq,verbose=verbose)
            if verbose:
                print(f"Calli Align: {time.time()-st:.2f}s")

        # Normalize the question: ensure an '<image>' placeholder when an
        # image is supplied.
        if pixel_values is None:
            question=questions

        if pixel_values is not None and '<image>' not in questions:
            question = '<image>\n' + questions

        elif history is None and pixel_values is None:
            question=questions
        elif '<image>' in questions:
            question=questions

        # One aligned-token placeholder per quantized embedding row; these
        # slots are filled with reference_embeds inside generate_ocr.
        if history is None and use_p and '[UNUSED_TOKEN_140]' not in question:
            question =question+'[UNUSED_TOKEN_140]'*out_tokens.shape[0]
        if num_patches_list is None:
            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)

        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_context_token_id = img_context_token_id

        template = get_conv_template(self.template)
        template.system_message = self.system_message
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)

        # Build the conversation prompt (history + current question).
        history = [] if history is None else history
        for (old_question, old_answer) in history:
            template.append_message(template.roles[0], old_question)
            template.append_message(template.roles[1], old_answer)
        template.append_message(template.roles[0], question)
        template.append_message(template.roles[1], None)
        query = template.get_prompt()

        # Expand each '<image>' into its run of image-context tokens.
        for num_patches in num_patches_list:
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN

            query = query.replace('<image>', image_tokens, 1)

        model_inputs = tokenizer(query, return_tensors='pt')

        input_ids = model_inputs['input_ids'].cuda()

        attention_mask = model_inputs['attention_mask'].cuda()

        generation_config['eos_token_id'] = eos_token_id

        if use_p:
            generation_output = self.generate_ocr(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                reference_embeds=out_tokens,
                repetition_penalty=repetition_penalty,
                **generation_config
            )
        else:
            generation_output = self.generate_ocr(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                repetition_penalty=repetition_penalty,
                **generation_config
            )
        response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
        response = response.split(template.sep)[0].strip()
        history.append((question, response))
        if return_history:
            return response, history
        else:
            query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
            query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')

            return response
| |
| |
    def dynamic_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
                     history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
                     IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None,batch=False,use_p=True):
        """Chat entry point supporting single and batched questions.

        batch=True expects `questions` to be a list (no history support) and
        decodes a list of responses; batch=False expects a single string and
        returns one response (plus history when return_history is True).
        use_p selects `self.generate` over `self.generate_origin`.
        """
        if use_p:
            # NOTE(review): mutates the instance-wide token budget — every
            # image is represented by 3 tokens for all later calls too.
            self.num_image_token=3
        if batch:
            assert isinstance(questions,list) and len(questions)>0 and isinstance(questions[0],str)
            if history is not None or return_history:
                print('Now multi-turn chat is not supported in batch_chat.')
                raise NotImplementedError

            if image_counts is not None:
                num_patches_list = image_counts
                print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')

            img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
            self.img_context_token_id = img_context_token_id

            if verbose and pixel_values is not None:
                image_bs = pixel_values.shape[0]
                print(f'dynamic ViT batch size: {image_bs}')

            # Build one templated prompt per question.
            # NOTE(review): assumes num_patches_list is provided in batch
            # mode — enumerate(None) would raise; confirm with callers.
            queries = []
            for idx, num_patches in enumerate(num_patches_list):
                question = questions[idx]
                if pixel_values is not None and '<image>' not in question:
                    question = '<image>\n' + question
                template = get_conv_template(self.template)
                template.append_message(template.roles[0], question)
                template.append_message(template.roles[1], None)
                query = template.get_prompt()

                image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
                query = query.replace('<image>', image_tokens, 1)
                queries.append(query)

            # Left padding so generation continues from the prompt end.
            tokenizer.padding_side = 'left'
            model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
            input_ids = model_inputs['input_ids'].cuda()
            attention_mask = model_inputs['attention_mask'].cuda()
            eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
            generation_config['eos_token_id'] = eos_token_id
            if use_p:
                generation_output = self.generate(
                    pixel_values=pixel_values,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    **generation_config
                )
            else:
                generation_output = self.generate_origin(
                    pixel_values=pixel_values,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    **generation_config
                )
            responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
            responses = [response.split(template.sep)[0].strip() for response in responses]
            return responses
        else:
            assert isinstance(questions,str)
            if num_patches_list is None:
                num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
            assert pixel_values is None or len(pixel_values) == sum(num_patches_list)

            img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
            self.img_context_token_id = img_context_token_id

            template = get_conv_template(self.template)
            template.system_message = self.system_message
            eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)

            history = [] if history is None else history
            for (old_question, old_answer) in history:
                template.append_message(template.roles[0], old_question)
                template.append_message(template.roles[1], old_answer)
            template.append_message(template.roles[0], questions)
            template.append_message(template.roles[1], None)
            query = template.get_prompt()

            if verbose and pixel_values is not None:
                image_bs = pixel_values.shape[0]
                print(f'dynamic ViT batch size: {image_bs}')

            # NOTE(review): the template-built `query` above is discarded —
            # the prompt is rebuilt by hand in InternLM2 chat format below;
            # only template.sep is still used to split the response.
            query=f"""<|im_start|>system你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。<|im_end|>\n<|im_start|>user{questions}"""
            query = query+'<image>'
            for num_patches in num_patches_list:
                # No IMG_START/IMG_END wrappers here, unlike the batch branch.
                image_tokens = IMG_CONTEXT_TOKEN * self.num_image_token

                query = query.replace('<image>', image_tokens, 1)

            query+="<|im_end|>\n<|im_start|>assistant"

            model_inputs = tokenizer(query, return_tensors='pt')

            input_ids = model_inputs['input_ids'].cuda()
            attention_mask = model_inputs['attention_mask'].cuda()

            generation_config['eos_token_id'] = eos_token_id
            if use_p:
                generation_output = self.generate(
                    pixel_values=pixel_values,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    **generation_config
                )
            else:
                generation_output = self.generate_origin(
                    pixel_values=pixel_values,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    **generation_config
                )
            response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
            response = response.split(template.sep)[0].strip()
            history.append((questions, response))
            if return_history:
                return response, history
            else:
                query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
                query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
                if verbose:
                    print(query_to_print, response)

                return response
|
|
    def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
                   history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
                   IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
        """Answer a batch of single-turn questions via generate_origin.

        `questions` is a list aligned with `num_patches_list`; history is
        not supported. Returns a list of decoded responses.
        """
        if history is not None or return_history:
            print('Now multi-turn chat is not supported in batch_chat.')
            raise NotImplementedError

        if image_counts is not None:
            num_patches_list = image_counts
            print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')

        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_context_token_id = img_context_token_id

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        # One templated prompt per question, with its '<image>' expanded
        # into the per-image run of context tokens.
        queries = []
        for idx, num_patches in enumerate(num_patches_list):
            question = questions[idx]
            if pixel_values is not None and '<image>' not in question:
                question = '<image>\n' + question
            template = get_conv_template(self.template)
            template.append_message(template.roles[0], question)
            template.append_message(template.roles[1], None)
            query = template.get_prompt()

            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
            query = query.replace('<image>', image_tokens, 1)
            queries.append(query)

        # Left padding so all prompts end at the same position for generation.
        tokenizer.padding_side = 'left'
        model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
        input_ids = model_inputs['input_ids'].cuda()
        attention_mask = model_inputs['attention_mask'].cuda()
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
        generation_config['eos_token_id'] = eos_token_id
        generation_output = self.generate_origin(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_config
        )
        responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
        responses = [response.split(template.sep)[0].strip() for response in responses]
        return responses
|
|
|
|
| |
    def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
             num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
             verbose=False):
        """Standard InternVL single/multi-turn chat using generate_origin.

        Builds the conversation prompt from `history` + `question`, expands
        each '<image>' into its run of image-context tokens, generates, and
        returns the decoded response (plus updated history when requested).
        """
        # Prepend the image placeholder only on the first turn.
        if history is None and pixel_values is not None and '<image>' not in question:
            question = '<image>\n' + question

        if num_patches_list is None:
            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)

        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_context_token_id = img_context_token_id

        template = get_conv_template(self.template)
        template.system_message = self.system_message
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)

        history = [] if history is None else history
        for (old_question, old_answer) in history:
            template.append_message(template.roles[0], old_question)
            template.append_message(template.roles[1], old_answer)
        template.append_message(template.roles[0], question)
        template.append_message(template.roles[1], None)
        query = template.get_prompt()

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]

        # Expand each '<image>' into IMG_START + context tokens + IMG_END.
        for num_patches in num_patches_list:
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
            query = query.replace('<image>', image_tokens, 1)
            # NOTE(review): leftover debug prints.
            print(num_patches,self.num_image_token)
        print(pixel_values.shape[0])

        model_inputs = tokenizer(query, return_tensors='pt')

        input_ids = model_inputs['input_ids'].cuda()
        attention_mask = model_inputs['attention_mask'].cuda()

        generation_config['eos_token_id'] = eos_token_id
        generation_output = self.generate_origin(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_config
        )
        response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
        response = response.split(template.sep)[0].strip()
        history.append((question, response))
        if return_history:
            return response, history
        else:
            query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
            query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
            if verbose:
                print(query_to_print, response)

            return response
| |
| @torch.no_grad() |
| def generate_origin( |
| self, |
| pixel_values: Optional[torch.FloatTensor] = None, |
| input_ids: Optional[torch.FloatTensor] = None, |
| attention_mask: Optional[torch.LongTensor] = None, |
| visual_features: Optional[torch.FloatTensor] = None, |
| generation_config: Optional[GenerationConfig] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| **generate_kwargs, |
| ) -> torch.LongTensor: |
|
|
| assert self.img_context_token_id is not None |
| if pixel_values is not None: |
| if visual_features is not None: |
| vit_embeds = visual_features |
| else: |
| vit_embeds = self.extract_feature(pixel_values) |
| input_embeds = self.language_model.get_input_embeddings()(input_ids) |
|
|
|
|
| B, N, C = input_embeds.shape |
| input_embeds = input_embeds.reshape(B * N, C) |
|
|
| input_ids = input_ids.reshape(B * N) |
| selected = (input_ids == self.img_context_token_id) |
| assert selected.sum() != 0 |
| input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device) |
| print("ID: ",self.img_context_token_id) |
| input_embeds = input_embeds.reshape(B, N, C) |
| else: |
| input_embeds = self.language_model.get_input_embeddings()(input_ids) |
| |
|
|
| outputs = self.language_model.generate( |
| inputs_embeds=input_embeds, |
| attention_mask=attention_mask, |
| generation_config=generation_config, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| use_cache=True, |
| **generate_kwargs, |
| ) |
|
|
| return outputs |
| @torch.no_grad() |
| def generate_ocr( |
| self, |
| pixel_values: Optional[torch.FloatTensor] = None, |
| input_ids: Optional[torch.FloatTensor] = None, |
| attention_mask: Optional[torch.LongTensor] = None, |
| visual_features: Optional[torch.FloatTensor] = None, |
| generation_config: Optional[GenerationConfig] = None, |
| reference_embeds=None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| repetition_penalty=1.5, |
| **generate_kwargs, |
| ) -> torch.LongTensor: |
|
|
| assert self.img_context_token_id is not None |
| if pixel_values is not None: |
| if visual_features is not None: |
| vit_embeds = visual_features |
| else: |
| vit_embeds = self.extract_feature(pixel_values) |
| input_embeds = self.language_model.get_input_embeddings()(input_ids) |
| |
|
|
| B, N, C = input_embeds.shape |
| input_embeds = input_embeds.reshape(B * N, C) |
|
|
| input_ids = input_ids.reshape(B * N) |
| selected = (input_ids == self.img_context_token_id) |
| assert selected.sum() != 0 |
| input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device) |
| |
|
|
| if reference_embeds is not None: |
| selected = (input_ids == 92537) |
| assert selected.sum() != 0 |
| input_embeds[selected] =reference_embeds.reshape(-1, C).to(input_embeds.device) |
| |
|
|
| input_embeds = input_embeds.reshape(B, N, C) |
| else: |
| input_embeds = self.language_model.get_input_embeddings()(input_ids) |
| |
| |
|
|
| outputs = self.language_model.generate( |
| inputs_embeds=input_embeds, |
| attention_mask=attention_mask, |
| generation_config=generation_config, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| use_cache=True, |
| repetition_penalty=repetition_penalty, |
| **generate_kwargs, |
| ) |
|
|
| return outputs |
    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        visual_features: Optional[torch.FloatTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """Generate token ids with resampled, re-statisticized vision embeddings.

        Unlike ``generate_origin``, the ViT embeddings are first passed through
        ``self.resampler`` and then rescaled per embedding row using the
        (sigma, mu) pair of the nearest entry found by ``vq_cos_sim`` against
        ``self.normed_emb``, before being spliced into the image-context token
        positions of the prompt embeddings.

        NOTE(review): the semantics of self.mu_sigma / self.normed_emb /
        vq_cos_sim are defined elsewhere in the project — presumably a
        codebook-style nearest-neighbor statistics lookup; confirm there.
        """
        assert self.img_context_token_id is not None
        if pixel_values is not None:
            if visual_features is not None:
                vit_embeds = visual_features
            else:
                vit_embeds = self.extract_feature(pixel_values)

            input_embeds = self.language_model.get_input_embeddings()(input_ids)

            # Project the vision embeddings through the learned resampler.
            vit_embeds = self.resampler(vit_embeds)

            # Per-row scale and shift; column 0 of mu_sigma is mu, column 1 is sigma.
            mu=self.mu_sigma[:,0].reshape((-1,1))
            sigma=self.mu_sigma[:,1].reshape((-1,1))

            # Nearest-neighbor index (by cosine similarity) for each embedding row.
            indices=vq_cos_sim(self.normed_emb,vit_embeds).reshape((-1,))

            # Rescale each row with the statistics of its matched entry.
            vit_embeds=vit_embeds.reshape((-1,vit_embeds.shape[-1]))*sigma[indices][:]+mu[indices][:]

            # Flatten to (B*N, C) so image-context positions can be scattered in one shot.
            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)

            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.img_context_token_id)

            # The prompt must contain at least one image-context token here.
            assert selected.sum() != 0

            # Assumes the number of selected positions equals the number of
            # rescaled embedding rows — TODO confirm against prompt expansion.
            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)

            input_embeds = input_embeds.reshape(B, N, C)
        else:
            input_embeds = self.language_model.get_input_embeddings()(input_ids)

        outputs = self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=True,
            **generate_kwargs,
        )

        return outputs
|
|