| from __future__ import annotations |
|
|
| import base64 |
| import json |
| import logging |
| import os |
| import uuid |
| from io import BytesIO |
|
|
| import requests |
| from PIL import Image |
|
|
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| from ..index_func import * |
| from ..presets import * |
| from ..utils import * |
| from .base_model import BaseLLMModel |
| from .. import shared |
|
|
# Load the Imp multimodal model ONCE at import time; it is shared (as module
# globals) by every XMChat instance via get_answer_at_once().
# NOTE(review): trust_remote_code=True executes Python shipped inside the
# "MILVLG/imp-v1-3b" hub repo — confirm the revision is trusted/pinned.
# NOTE(review): loading here blocks module import and triggers a weight
# download on first use; consider lazy-loading if startup time matters.
imp_model = AutoModelForCausalLM.from_pretrained(
    "MILVLG/imp-v1-3b",
    torch_dtype=torch.float32,  # CPU-safe dtype; device placement is left to device_map="auto"
    device_map="auto",
    trust_remote_code=True)
imp_tokenizer = AutoTokenizer.from_pretrained("MILVLG/imp-v1-3b", trust_remote_code=True)
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
class XMChat(BaseLLMModel):
    """Chat model that routes text (and optionally one image per turn)
    through the module-level Imp model (``imp_model`` / ``imp_tokenizer``).

    Despite the class name, generation is done locally by Imp; the
    ``xmbot.net`` URL is only used by the :meth:`like` / :meth:`dislike`
    feedback endpoints.
    """

    def __init__(self, api_key, user_name="", common_model=None, common_tokenizer=None):
        """Initialize the session.

        Args:
            api_key: API key kept for the xmbot feedback endpoint.
            user_name: forwarded to the base model as the user identity.
            common_model / common_tokenizer: optional externally supplied
                model/tokenizer (stored but not used by this class directly).
        """
        super().__init__(model_name="xmchat", user=user_name)
        self.api_key = api_key
        self.image_flag = False      # True -> the next user turn carries an image
        self.session_id = None
        self.reset()
        self.image_bytes = None      # PIL.Image of the last upload (name is historical)
        self.image_path = None
        self.xm_history = []
        self.url = "https://xmbot.net/web"
        self.last_conv_id = None
        self.max_generation_token = 100

        self.common_model = common_model
        self.common_tokenizer = common_tokenizer
        self.system_prompt = "A chat between a curious user and an artificial intelligence assistant. This artificial intelligence assistant is a chatbot named as Imp, and developed by MILVLG team. Imp gives helpful, detailed, and polite answers to the user's questions."

    def reset(self, remain_system_prompt=False):
        """Start a fresh session: new session id, cleared image state.

        NOTE(review): ``remain_system_prompt`` is accepted for interface
        compatibility but is not forwarded to ``super().reset()`` — confirm
        against the base-class signature whether it should be.
        """
        logging.info("Resetting...")  # fixed typo ("Reseting")
        self.session_id = str(uuid.uuid4())
        self.last_conv_id = None
        self.image_bytes = None
        self.image_flag = False
        return super().reset()

    def image_to_base64(self, image_path):
        """Load an image, cap its longest side at 2048 px, and return it as a
        base64-encoded JPEG string.

        Args:
            image_path: path to an image file Pillow can open.
        Returns:
            str: base64 (utf-8) of the JPEG-encoded, possibly resized image.
        """
        max_dimension = 2048
        # Context manager fixes the original's leaked file handle.
        with Image.open(image_path) as img:
            width, height = img.size
            scale_ratio = min(max_dimension / width, max_dimension / height)
            if scale_ratio < 1:
                # Only ever downscale; small images are left untouched.
                new_width = int(width * scale_ratio)
                new_height = int(height * scale_ratio)
                img = img.resize((new_width, new_height), Image.LANCZOS)

            # JPEG cannot encode alpha/palette modes (RGBA, P, LA, ...);
            # the original only handled RGBA and crashed on P-mode PNG/GIF.
            if img.mode not in ("RGB", "L", "CMYK"):
                img = img.convert("RGB")

            buffer = BytesIO()
            img.save(buffer, format='JPEG')

        binary_image = buffer.getvalue()
        base64_image = base64.b64encode(binary_image).decode('utf-8')
        return base64_image

    def try_read_image(self, filepath):
        """If ``filepath`` looks like an image, load it and arm ``image_flag``
        so the next prompt is built with an ``<image>`` tag; otherwise clear
        all image state."""
        def is_image_file(filepath):
            # Extension-based check only; file content is not sniffed.
            valid_image_extensions = [
                ".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"]
            file_extension = os.path.splitext(filepath)[1].lower()
            return file_extension in valid_image_extensions

        if is_image_file(filepath):
            logging.info(f"读取图片文件: {filepath}")
            self.image_bytes = Image.open(filepath)
            self.image_path = filepath
            self.image_flag = True
        else:
            self.image_bytes = None
            self.image_path = None
            # Bug fix: also clear the flag. Previously, uploading a non-image
            # after an image left image_flag=True with image_bytes=None, which
            # would crash image_preprocess() in get_answer_at_once().
            self.image_flag = False

    def like(self):
        """Send a 'good' appraisal for the last conversation turn to xmbot."""
        if self.last_conv_id is None:
            return "点赞失败,你还没发送过消息"
        data = {
            "uuid": self.last_conv_id,
            "appraise": "good"
        }
        # timeout added so a dead endpoint cannot hang the UI indefinitely
        requests.post(self.url, json=data, timeout=10)
        return "👍点赞成功,感谢反馈~"

    def dislike(self):
        """Send a 'bad' appraisal for the last conversation turn to xmbot."""
        if self.last_conv_id is None:
            return "点踩失败,你还没发送过消息"
        data = {
            "uuid": self.last_conv_id,
            "appraise": "bad"
        }
        # timeout added so a dead endpoint cannot hang the UI indefinitely
        requests.post(self.url, json=data, timeout=10)
        return "👎点踩成功,感谢反馈~"

    def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot):
        """Pass-through: this model does no web search or file indexing, so
        the inputs are returned unmodified (display text == real prompt)."""
        fake_inputs = real_inputs
        display_append = ""
        limited_context = False
        return limited_context, fake_inputs, display_append, real_inputs, chatbot

    def handle_file_upload(self, files, chatbot, language):
        """Accept uploaded files, keeping the last recognizable image.

        Each image file is loaded via try_read_image() (which arms
        image_flag) and echoed into the chatbot display. Non-image files are
        ignored. Returns (None, chatbot, None) to match the base interface.
        """
        if files:
            for file in files:
                if file.name:
                    logging.info(f"尝试读取图像: {file.name}")
                    self.try_read_image(file.name)
            if self.image_path is not None:
                chatbot = chatbot + [((self.image_path,), None)]
        return None, chatbot, None

    def _get_imp_style_inputs(self):
        """Serialize self.history into Imp's expected prompt format:
        ``<system> USER: ... ASSISTANT: ...</s> ... ASSISTANT:``.

        If an image is pending (image_flag), an ``<image>`` tag is prepended
        to the LAST user message. NOTE: this mutates self.history in place,
        by design — the ``replace('<image>\\n', '')`` below strips the tag
        inserted for a previous image so only one tag survives per prompt.
        """
        context = """
A chat between a curious user and an artificial intelligence assistant. This artificial intelligence assistant is a multimodal chatbot named as Imp, and developed by MILVLG team from Hangzhou Dianzi University. Imp gives helpful, detailed, and polite answers to the user's questions.
""".strip()
        for ii, i in enumerate(self.history):
            if i["role"] == "user":
                if self.image_flag and ii == len(self.history) - 1:
                    # Drop any earlier image tag, then tag the newest turn.
                    context = context.replace('<image>\n', '')
                    i["content"] = '<image>\n' + i["content"]
                    self.image_flag = False
                context += ' USER: ' + i["content"].strip()
            else:
                context += ' ASSISTANT: ' + i["content"].strip() + '</s>'
        context += ' ASSISTANT:'
        return context

    def get_answer_at_once(self):
        """Run one non-streaming generation with the shared Imp model.

        Returns:
            tuple[str, int]: (response text, length of the response string —
            a character count, not a token count).
        """
        global imp_model, imp_tokenizer
        prompt = self._get_imp_style_inputs()
        logging.info(prompt)

        input_ids = imp_tokenizer(prompt, return_tensors='pt').input_ids
        image_tensor = None
        if '<image>' in prompt:
            # self.image_bytes is the PIL.Image loaded by try_read_image().
            image_tensor = imp_model.image_preprocess(self.image_bytes)
        # NOTE(review): max_new_tokens=3000 ignores self.max_generation_token
        # (100) — confirm which limit is intended.
        output_ids = imp_model.generate(
            input_ids,
            max_new_tokens=3000,
            images=image_tensor,
            do_sample=True if self.temperature > 0 else False,
            top_p=self.top_p,
            temperature=self.temperature,
            num_return_sequences=1,
            use_cache=True)[0]
        # Slice off the prompt tokens; decode only the newly generated tail.
        response = imp_tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
        return response, len(response)
|