# plm_internvl_ola_code/inference/infer_ola_internvl.py
import os
os.environ['LOWRES_RESIZE'] = '384x32'
os.environ['HIGHRES_BASE'] = '0x32'
os.environ['VIDEO_RESIZE'] = "0x64"
os.environ['VIDEO_MAXRES'] = "480"
os.environ['VIDEO_MINRES'] = "288"
os.environ['MAXRES'] = '1536'
os.environ['MINRES'] = '0'
os.environ['FORCE_NO_DOWNSAMPLE'] = '1'
os.environ['LOAD_VISION_EARLY'] = '1'
os.environ['PAD2STRIDE'] = '1'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
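# NOTE: the variables above are presumably read by the Ola preprocessing code at
# import time, which is why they are set before any `ola` module is imported.
# CUDA_LAUNCH_BLOCKING=1 forces synchronous CUDA kernel launches for easier debugging.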
import sys
from pathlib import Path
import math
import numpy as np
import torch
import torchvision.transforms as T
from decord import VideoReader, cpu  # was temporarily disabled while focusing on the speech tests
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from contextlib import redirect_stdout
import io
import librosa
import whisper
import moviepy as mp
from transformers import AutoTokenizer, AutoConfig, AutoModel
# Test modes covered below:
# - pure text
# - image + text
# - video + text
# - audio + text
# - video + audio + text
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
from ola.conversation import conv_templates, SeparatorStyle
from ola.model.builder import load_pretrained_model
from ola.datasets.preprocess import tokenizer_image_token, tokenizer_speech_image_token, tokenizer_speech_question_image_token, tokenizer_speech_token
from ola.mm_utils import KeywordsStoppingCriteria, process_anyres_video, process_anyres_highres_image
from ola.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN, SPEECH_TOKEN_INDEX
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, default='/data1/cxy/plm-v/modeling/internvl3_5-2B')
parser.add_argument('--text', type=str, default="What does the speech say?")
parser.add_argument('--audio_path', type=str, default=None)
parser.add_argument('--image_path', type=str, default=None)
parser.add_argument('--video_path', type=str, default=None)
args = parser.parse_args()
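# NOTE: only --model_path is actually consumed below; --text, --audio_path,
# --image_path and --video_path are parsed but the script currently hardcodes
# its own example paths instead.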
model_path = args.model_path
tokenizer, model, image_processor, _ = load_pretrained_model(model_path, 'ola_internvl', None)
model = model.to('cuda').bfloat16().eval()
resource_path = "/data1/cxy/plm-v/modeling/example/"
# set the max number of tiles in `max_num`
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            # on a tie, prefer the larger grid when the source image is big enough
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio
def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height
    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)
    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images
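# Rough worked example (my own illustration, not from the original script):
# a 1280x720 input (aspect ratio ~1.78) with image_size=448 and max_num=12
# ties between the (2, 1) and (4, 2) grids (both 2.0); since the image area
# exceeds half of the (4, 2) grid's pixel area, the tie-break picks (4, 2),
# so the image is resized to 1792x896 and cut into 8 tiles of 448x448, plus
# one global thumbnail when use_thumbnail=True, i.e. 9 crops in total.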
def load_image(image_file, input_size=448, max_num=12):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values
pixel_values = load_image(f'{resource_path}image1.jpg', max_num=12).to(torch.bfloat16).cuda()
generation_config = dict(max_new_tokens=1024, do_sample=True)
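# NOTE: do_sample=True above makes generations nondeterministic; switch to
# do_sample=False if you want reproducible smoke-test transcripts.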
question = 'Hello, who are you?'
response, history = model.chat(tokenizer, None, question, generation_config, history=None, return_history=True)
print(f'User: {question}\nAssistant: {response}')
# Multimodal inference tests
print("\n" + "="*80)
print("🧪 Starting multimodal inference tests")
print("="*80)
def test_inference(test_name, question, pixel_values_input=None, speech_input=None, speech_lengths_input=None, num_patches_list=None):
    """Shared inference test helper."""
    print(f"\n{'='*60}")
    print(f"🧪 Test: {test_name}")
    print(f"📝 Question: {question}")
    print(f"{'='*60}")
    try:
        # assemble the chat() arguments
        chat_kwargs = {
            'tokenizer': tokenizer,
            'pixel_values': pixel_values_input,
            'question': question,
            'generation_config': generation_config,
            'verbose': True
        }
        # pass num_patches_list only for video inputs
        if num_patches_list is not None:
            chat_kwargs['num_patches_list'] = num_patches_list
        # pass the speech tensors only for audio inputs
        if speech_input is not None:
            chat_kwargs.update({
                'speech': speech_input,             # mel spectrogram, consumed by Whisper
                'speech_lengths': speech_lengths_input,
                'speech_wav': speech_wavs,          # raw waveform, consumed by BEATs
            })
        # run inference
        response = model.chat(**chat_kwargs)
        print("✅ Inference succeeded!")
        print(f"🤖 Response: {response}")
        return True, response
    except Exception as e:
        print(f"❌ Inference failed: {str(e)}")
        import traceback
        traceback.print_exc()
        return False, str(e)
# Test 1: Pure text (should work; exercises the trained InternVL weights)
success1, response1 = test_inference(
    test_name="Pure Text",
    question="Hello, who are you? Please introduce yourself briefly.",
    pixel_values_input=None,
    speech_input=None,
    speech_lengths_input=None
)
# Test 2: Text & image, visual only (should work; exercises the trained InternVL weights)
# success2, response2 = test_inference(
#     test_name="Text & Image (Visual only)",
#     question="<image>\nPlease describe this image in detail.",
#     pixel_values_input=pixel_values,
#     speech_input=None,
#     speech_lengths_input=None
# )
print("\n" + "="*60)
print("🔄 准备Speech相关测试 (可能输出乱码,因为speech部分未训练)")
print("="*60)
def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
    if bound:
        start, end = bound[0], bound[1]
    else:
        start, end = -100000, 100000
    start_idx = max(first_idx, round(start * fps))
    end_idx = min(round(end * fps), max_frame)
    seg_size = float(end_idx - start_idx) / num_segments
    frame_indices = np.array([
        int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
        for idx in range(num_segments)
    ])
    return frame_indices
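# Illustration (mine, not from the original script): for an unbounded 300-frame
# clip with num_segments=8, seg_size is 299/8 ≈ 37.4 and the sampled indices sit
# near the midpoint of each segment, roughly [18, 55, 93, 130, 168, 205, 242, 280].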
def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    max_frame = len(vr) - 1
    fps = float(vr.get_avg_fps())
    pixel_values_list, num_patches_list = [], []
    transform = build_transform(input_size=input_size)
    frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
    for frame_index in frame_indices:
        img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB')
        img = dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num)
        pixel_values = [transform(tile) for tile in img]
        pixel_values = torch.stack(pixel_values)
        num_patches_list.append(pixel_values.shape[0])
        pixel_values_list.append(pixel_values)
    pixel_values = torch.cat(pixel_values_list)
    return pixel_values, num_patches_list
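# Shape note (my reading of the code above): with max_num=1 each frame yields a
# single (1, 1) tile and no thumbnail, so num_segments=8 returns pixel_values of
# shape (8, 3, 448, 448) and num_patches_list == [1] * 8.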
def load_audio(audio_file_name):
    """
    Load an audio file using Ola-style mel-spectrogram preprocessing,
    matching the original Ola load_audio function.
    """
    speech_wav, samplerate = librosa.load(audio_file_name, sr=16000)
    if len(speech_wav.shape) > 1:
        speech_wav = speech_wav[:, 0]
    speech_wav = speech_wav.astype(np.float32)
    CHUNK_LIM = 480000
    SAMPLE_RATE = 16000
    speechs = []
    speech_wavs = []
    if len(speech_wav) <= CHUNK_LIM:
        speech = whisper.pad_or_trim(speech_wav)
        speech_wav_chunk = whisper.pad_or_trim(speech_wav)
        speechs.append(speech)
        speech_wavs.append(torch.from_numpy(speech_wav_chunk).unsqueeze(0))
    else:
        for i in range(0, len(speech_wav), CHUNK_LIM):
            chunk = speech_wav[i : i + CHUNK_LIM]
            if len(chunk) < CHUNK_LIM:
                chunk = whisper.pad_or_trim(chunk)
            speechs.append(chunk)
            speech_wavs.append(torch.from_numpy(chunk).unsqueeze(0))
    # compute the log-mel spectrograms
    mels = []
    for chunk in speechs:
        chunk = whisper.log_mel_spectrogram(chunk, n_mels=128).permute(1, 0).unsqueeze(0)
        mels.append(chunk)
    mels = torch.cat(mels, dim=0)
    speech_wavs = torch.cat(speech_wavs, dim=0)
    if mels.shape[0] > 25:
        mels = mels[:25]
        speech_wavs = speech_wavs[:25]
    speech_length = torch.LongTensor([mels.shape[1]] * mels.shape[0])
    speech_chunks = torch.LongTensor([mels.shape[0]])
    return mels, speech_length, speech_chunks, speech_wavs
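# Shape note (my reading of the code above): CHUNK_LIM is 30 s at 16 kHz, each
# padded chunk becomes a (3000, 128) log-mel spectrogram, and at most 25 chunks
# (12.5 minutes) are kept, so mels is (n_chunks, 3000, 128) with n_chunks <= 25.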
def extract_audio(videos_file_path):
    my_clip = mp.VideoFileClip(videos_file_path)
    return my_clip.audio
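# extract_audio() is currently unused; a hypothetical way to chain it into
# load_audio() would be to dump the track to disk first, e.g.:
#   extract_audio('clip.mp4').write_audiofile('clip_audio.wav', fps=16000)
#   mels, lengths, chunks, wavs = load_audio('clip_audio.wav')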
# Load video data for the video tests
print("\n📥 Loading video data...")
try:
    video_path = f'{resource_path}red-panda.mp4'
    if os.path.exists(video_path):
        video_pixel_values, video_num_patches_list = load_video(video_path, num_segments=8, max_num=1)
        video_pixel_values = video_pixel_values.to(torch.bfloat16).cuda()
        video_loaded = True
        print("✅ Video loaded:")
        print(f"   - frames: {len(video_num_patches_list)}")
        print(f"   - pixel_values shape: {video_pixel_values.shape}")
        print(f"   - patches per frame: {video_num_patches_list}")
    else:
        print(f"⚠️ Video file not found: {video_path}")
        video_loaded = False
        video_pixel_values = None
        video_num_patches_list = None
except Exception as e:
    print(f"❌ Video loading failed: {e}")
    video_loaded = False
    video_pixel_values = None
    video_num_patches_list = None
audio_path = '/data1/cxy/dataset/english.mp3'
# Load audio data for the speech tests
print("\n📥 Loading audio data...")
try:
    # Ola-style mel-spectrogram preprocessing
    mels, speech_lengths, speech_chunks, speech_wavs = load_audio(audio_path)
    print("✅ Audio loaded:")
    print(f"   - mel spectrogram shape: {mels.shape}")
    print(f"   - speech lengths: {speech_lengths}")
    print(f"   - speech chunks: {speech_chunks}")
    print(f"   - raw waveform shape: {speech_wavs.shape}")
    # cast to the model dtype and move everything to the GPU
    mels = mels.to(torch.bfloat16).cuda()
    speech_lengths = speech_lengths.cuda()
    speech_chunks = speech_chunks.cuda()
    speech_wavs = speech_wavs.cuda()
    audio_loaded = True
except Exception as e:
    print(f"❌ Audio loading failed: {e}")
    audio_loaded = False
    mels = None
    speech_lengths = None
    speech_chunks = None  # also reset these so later references cannot hit a NameError
    speech_wavs = None
# Test 3: Audio only (output may be garbled; the speech branch is untrained)
if audio_loaded:
    success3, response3 = test_inference(
        test_name="Audio only (garbled output expected)",
        question="<speech>\nPlease transcribe and summarize what you heard in the audio.",
        pixel_values_input=None,
        speech_input=mels,
        speech_lengths_input=speech_lengths
    )
else:
    print("⚠️ Skipping the Audio-only test (audio loading failed)")
    success3 = False
# # Test 4: Audio + image (output may be garbled; the speech branch is untrained)
# if audio_loaded:
#     success4, response4 = test_inference(
#         test_name="Audio + Image (garbled output expected)",
#         question="<image>\nUser's question in speech: <speech>\n",
#         pixel_values_input=pixel_values,
#         speech_input=mels,
#         speech_lengths_input=speech_lengths
#     )
# else:
#     print("⚠️ Skipping the Audio + Image test (audio loading failed)")
#     success4 = False
# Test 5: Video + text (should work; exercises the trained InternVL weights)
# if video_loaded:
#     # build the per-frame prefix
#     video_prefix = ''.join([f'Frame{i+1}: <image>\n' for i in range(len(video_num_patches_list))])
#     video_question = video_prefix + 'What is the red panda doing in this video? Please describe the actions and movements you observe.'
#     success5, response5 = test_inference(
#         test_name="Video + Text",
#         question=video_question,
#         pixel_values_input=video_pixel_values,
#         speech_input=None,
#         speech_lengths_input=None,
#         num_patches_list=video_num_patches_list
#     )
# else:
#     print("⚠️ Skipping the Video + Text test (video loading failed)")
#     success5 = False
# Test 6: Video + audio (output may be garbled; the speech branch is untrained)
# (renumbered from a duplicate "Test 5"; note this block still passes the image
# pixel_values rather than the video tensors)
# if audio_loaded:
#     success6, response6 = test_inference(
#         test_name="Video + Audio (garbled output expected)",
#         question="<speech><image>\nDescribe what you hear and see in this content.",
#         pixel_values_input=pixel_values,
#         speech_input=mels,
#         speech_lengths_input=speech_lengths
#     )
# else:
#     print("⚠️ Skipping the Video + Audio test (audio loading failed)")
#     success6 = False
# Test summary
print("\n" + "="*80)
print("📊 Multimodal inference test summary")
print("="*80)
test_results = [
    ("Pure Text", success1, "PASS", "should work (trained InternVL)"),
    # ("Text & Image", success2, "PASS", "should work (trained InternVL)"),
    # ("Video + Text", success5 if video_loaded else False, "PASS", "should work (trained InternVL)"),
    ("Audio only", success3 if audio_loaded else False, "GARBLED", "may be garbled (speech branch untrained)"),
    # ("Audio + Image", success4 if audio_loaded else False, "GARBLED", "may be garbled (speech branch untrained)"),
]
for test_name, success, expected, note in test_results:
    status = "✅ PASS" if success else "❌ FAIL"
    print(f"{status} {test_name:<15} (expected: {expected:<8}) - {note}")
passed = sum(1 for _, success, _, _ in test_results if success)
total = len(test_results)
print(f"\n📈 Test stats: {passed}/{total} passed")
if passed >= 2:  # at least 2 of pure text / text & image / video + text should pass
    print("🎉 Basic functionality works; the speech integration is wired up correctly!")
    print("💡 Garbled output from the speech tests is expected: the speech branch has not been trained yet")
    if passed >= 3:
        print("🌟 All basic modality tests passed!")
else:
    print("⚠️ Basic functionality may be broken; further investigation needed")
print("\n=== Multimodal inference tests complete ===")