| import cv2 |
| import gradio as gr |
| from llm import LLM |
| import llm as llm |
| import argparse |
| import socket |
|
|
# Command-line configuration: where to find the HuggingFace model directory
# and the compiled axmodel artifacts handed to the LLM wrapper below.
parser = argparse.ArgumentParser(description="Model configuration parameters")
parser.add_argument(
    "--hf_model",
    type=str,
    default="./InternVL3-2B",
    help="Path to HuggingFace model",
)
parser.add_argument(
    "--axmodel_path",
    type=str,
    default="./InternVL3-2B_axmodel",
    help="Path to save compiled axmodel of llama model",
)
# The help text used to be a verbatim copy of the --axmodel_path help, which
# described the wrong artifact; this option feeds vit_axmodel_path.
parser.add_argument(
    "--vit_model",
    type=str,
    default=None,
    help="Path to the compiled axmodel of the ViT model",
)
args = parser.parse_args()


hf_model_path = args.hf_model
axmodel_path = args.axmodel_path
vit_axmodel_path = args.vit_model


# Single shared model instance used by every request handler in this file.
gllm = LLM(hf_model_path, axmodel_path, vit_axmodel_path)
|
|
| |
def get_local_ip():
    """Return the local IPv4 address the OS would use for outbound traffic.

    A datagram socket is connected to a public address so that
    ``getsockname()`` reports the local address chosen for that route.

    Returns:
        The local IPv4 address as a string, or ``"127.0.0.1"`` when any
        socket-level error prevents the lookup.
    """
    try:
        # The context manager closes the socket even when connect() or
        # getsockname() raises; a manual close() after them would be skipped.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
            probe.connect(("8.8.8.8", 80))
            return probe.getsockname()[0]
    except OSError:
        # Best effort: fall back to loopback when no address can be found.
        return "127.0.0.1"
|
|
def stop_generation():
    """Forward a stop request to the shared ``gllm`` instance.

    Used as the click handler of the UI's Stop button; takes no inputs and
    produces no outputs.
    """
    gllm.stop_generate()
|
|
def _split_image_grid(img, rows, cols):
    """Cut *img* into ``rows * cols`` tiles, ordered row by row.

    Tile size is the integer quotient of the image size by the grid size, so
    remainder pixels along the right and bottom edges are not included.
    """
    height, width = img.shape[:2]
    tile_h = height // rows
    tile_w = width // cols
    return [
        img[r * tile_h:(r + 1) * tile_h, c * tile_w:(c + 1) * tile_w]
        for r in range(rows)
        for c in range(cols)
    ]


def respond(prompt, video, image, is_image, video_segments, image_segments_cols, image_segments_rows, history=None):
    """Stream the model's answer to *prompt* into the chat history.

    This is a generator: every ``yield`` hands the current history back to
    the caller so the chat display can be refreshed while text arrives.

    Args:
        prompt: User message; a blank prompt leaves the history unchanged.
        video: Path of the uploaded video (used when ``is_image`` is false).
        image: Path of the uploaded image (used when ``is_image`` is true).
        is_image: Selects the image input instead of the video input.
        video_segments: Passed as ``num_segments`` to
            ``llm.load_video_opencv``.
        image_segments_cols: Number of tile columns the image is cut into.
        image_segments_rows: Number of tile rows the image is cut into.
        history: Existing list of ``(prompt, answer)`` pairs, or ``None``.

    Yields:
        The history list, whose last entry is updated in place as the answer
        (or an error message describing an unusable input) is produced.
    """
    if history is None:
        history = []
    if not prompt or not prompt.strip():
        # A value passed to ``return`` inside a generator never reaches the
        # consumer, so hand the unchanged history back explicitly.
        yield history
        return

    history.append((prompt, ""))
    yield history

    print(video)
    print(image)

    error = None
    images_list = None
    if is_image:
        # Coerce the slider values so they are always valid for range().
        rows = int(image_segments_rows)
        cols = int(image_segments_cols)
        if not image:
            error = "no image uploaded"
        elif rows * cols > 8:
            error = "image segments cols * image segments rows > 8"
        else:
            img = cv2.imread(image)
            if img is None:
                # cv2.imread signals an unreadable file by returning None.
                error = f"failed to read image: {image}"
            elif rows == 1 and cols == 1:
                images_list = [img]
            else:
                images_list = _split_image_grid(img, rows, cols)
    elif not video:
        error = "no video uploaded"
    else:
        images_list = llm.load_video_opencv(video, num_segments=video_segments)

    if error is not None:
        # Report the problem in the chat instead of raising inside the handler.
        history[-1] = (prompt, error)
        yield history
        return

    # Only touch the shared model once the inputs are known to be usable.
    gllm.tag = "image" if is_image else "video"

    for msg in gllm.generate(images_list, prompt):
        print(msg, end="", flush=True)
        history[-1] = (prompt, history[-1][1] + msg)
        yield history
    print("\n\n\n")
| |
def chat_interface():
    """Build the Gradio chat page and launch it on this machine's address.

    The page has a media column (video, image, segment sliders and the
    "Use Image" checkbox) and a chat column (chatbot, prompt box, Chat and
    Stop buttons). The server is started on port 7860 at the address
    returned by ``get_local_ip()``.
    """

    def _checkbox_setter(flag):
        """Return a change handler that sets the checkbox to *flag*."""

        def _handler(media):
            # A cleared upload (None) leaves the checkbox as it is.
            if media is None:
                return gr.update()
            return gr.update(value=flag)

        return _handler

    with gr.Blocks() as demo:
        gr.Markdown("## InternVL3 Chat DEMO \nUpload an image or video and chat with the InternVL3.")
        with gr.Row():
            with gr.Column(scale=1):
                video_in = gr.Video(label="Upload Video", format="mp4")
                video_seg_slider = gr.Slider(
                    minimum=2, maximum=8, step=1, value=4, label="video segments"
                )
                image_in = gr.Image(label="Upload Image", type="filepath")
                cols_slider = gr.Slider(
                    minimum=1, maximum=4, step=1, value=1, label="image cols segments"
                )
                rows_slider = gr.Slider(
                    minimum=1, maximum=4, step=1, value=1, label="image rows segments"
                )
                use_image_box = gr.Checkbox(label="Use Image")
            with gr.Column(scale=3):
                chat_view = gr.Chatbot(height=650)
                prompt_box = gr.Textbox(
                    placeholder="Type your message...",
                    label="Prompt",
                    value="描述一下这组图片",
                )
                with gr.Row():
                    chat_button = gr.Button("Chat", variant="primary")
                    stop_button = gr.Button("Stop", variant="stop")

        stop_button.click(fn=stop_generation, inputs=None, outputs=None)

        # Input order must match the positional parameters of respond().
        respond_inputs = [
            prompt_box,
            video_in,
            image_in,
            use_image_box,
            video_seg_slider,
            cols_slider,
            rows_slider,
            chat_view,
        ]
        chat_button.click(fn=respond, inputs=respond_inputs, outputs=chat_view)

        # Uploading a video unticks "Use Image"; uploading an image ticks it.
        video_in.change(fn=_checkbox_setter(False), inputs=video_in, outputs=use_image_box)
        image_in.change(fn=_checkbox_setter(True), inputs=image_in, outputs=use_image_box)

    host = get_local_ip()
    port = 7860
    print(f"HTTP 服务地址: http://{host}:{port}")
    demo.launch(server_name=host, server_port=port)
|
|
# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    chat_interface()
|
|