import abc
import json
import os
import re
import time
from queue import Queue
from threading import Thread
from typing import List, Optional

from langchain.chains.base import Chain

from app_modules.llm_loader import LLMLoader, TextIteratorStreamer
from app_modules.utils import remove_extra_spaces

chat_history_enabled = os.getenv("CHAT_HISTORY_ENABLED", "false").lower() == "true"


def get_system_prompt_and_user_message(orca=False):
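    """Return the (system_prompt, user_message) template pair.

    The Orca persona is used when ``orca`` is True; otherwise a generic
    chatbot persona. When CHAT_HISTORY_ENABLED is set, a ``{history}``
    placeholder is prepended to the user message and the system prompt
    asks the model to read the chat history.
    """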
    system_prompt = (
        "You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior."
        if orca
        else "You are a chatbot having a conversation with a human."
    )

    user_message = "{input}"

    if chat_history_enabled:
        user_message = "Chat History:\n\n{history} \n\n" + user_message
        system_prompt += " Read the chat history to get context."

    return system_prompt, user_message


class LLMInference(metaclass=abc.ABCMeta):
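    """Abstract base for LLM inference wrappers built on a LangChain chain.

    Subclasses implement ``create_chain``; this base class provides lazy
    chain creation, batch and single-input execution, result normalization,
    and optional token streaming on a background thread.
    """
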
    def __init__(self, llm_loader: LLMLoader):
        self.llm_loader = llm_loader
        self.chain = None
        # Matches a trailing special token such as "</s>" so it can be
        # stripped from generated answers.
        self.pattern = re.compile(r"\s*<.+>$")

    @abc.abstractmethod
    def create_chain(self) -> Chain:
        pass

    def get_chain(self) -> Chain:
        if self.chain is None:
            self.chain = self.create_chain()

        return self.chain

    def reset(self) -> None:
        self.chain = None

    def _process_inputs(self, inputs):
        return inputs

    def _normalize_result(self, result):
        # Batch results may arrive wrapped in a single-element list.
        if isinstance(result, list):
            result = result[0]

        # Unify LangChain ("text") and HuggingFace ("generated_text")
        # output keys under a single "answer" key.
        key = "text" if "text" in result else "generated_text"
        if key in result:
            result["answer"] = result[key]
            del result[key]

        if "answer" in result:
            # Strip a trailing special token such as "</s>".
            result["answer"] = self.pattern.sub("", result["answer"])
        return result

    def _process_results(self, results):
        if isinstance(results, list):
            return [self._normalize_result(result) for result in results]

        return self._normalize_result(results)

    def _run_batch(self, chain, inputs):
        if self.llm_loader.llm_model_type == "huggingface":
            # Call the underlying HF pipeline directly for batched inputs.
            results = self.llm_loader.llm.pipeline(inputs)
        else:
            results = chain.batch(inputs)

        return results

    def run_chain(self, chain, inputs, callbacks: Optional[List] = None):
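        """Run ``chain`` on ``inputs`` and normalize the results.

        A list of inputs is executed as a batch; a single input dict is
        passed to ``chain.invoke`` with the given callbacks. Generated text
        always ends up under the ``answer`` key.
        """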
        inputs = self._process_inputs(inputs)

        if isinstance(inputs, list):  # batch mode
            results = self._run_batch(chain, inputs)
        else:
            results = chain.invoke(inputs, {"callbacks": callbacks or []})

        return self._process_results(results)

    def call_chain(
        self,
        inputs,
        streaming_handler,
        q: Optional[Queue] = None,
        testing: bool = False,
    ):
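        """Public entry point: run the chain, optionally streaming tokens.

        When HuggingFace streaming is enabled, a lock serializes access to
        the shared streamer, which is reset (with ``q``, when given) before
        the run. Extra spaces are removed from the final answer.
        """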
        print(json.dumps(inputs, indent=4))
        if self.llm_loader.huggingfaceStreamingEnabled():
            # Serialize access to the shared streamer across requests.
            self.llm_loader.lock.acquire()

        try:
            if self.llm_loader.huggingfaceStreamingEnabled():
                self.llm_loader.streamer.reset(q)

            chain = self.get_chain()
            result = (
                self._run_chain_with_streaming_handler(
                    chain, inputs, streaming_handler, testing
                )
                if streaming_handler is not None
                else self.run_chain(chain, inputs)
            )

            if "answer" in result:
                result["answer"] = remove_extra_spaces(result["answer"])

            return result
        finally:
            if self.llm_loader.huggingfaceStreamingEnabled():
                self.llm_loader.lock.release()

    def _execute_chain(self, chain, inputs, q, sh):
        q.put(self.run_chain(chain, inputs, callbacks=[sh]))

    def _run_chain_with_streaming_handler(
        self, chain, inputs, streaming_handler, testing
    ):
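        """Run the chain on a worker thread while relaying streamed tokens.

        Tokens read from the shared streamer are forwarded to
        ``streaming_handler.on_llm_new_token``; the worker's result is
        returned once streaming completes.
        """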
        que = Queue()

        t = Thread(
            target=self._execute_chain,
            args=(chain, inputs, que, streaming_handler),
        )
        t.start()

        if self.llm_loader.huggingfaceStreamingEnabled():
            # Two generation passes are expected when chat history is
            # present, one otherwise.
            count = (
                2
                if "chat_history" in inputs and len(inputs.get("chat_history")) > 0
                else 1
            )

            while count > 0:
                try:
                    for token in self.llm_loader.streamer:
                        if not testing:
                            streaming_handler.on_llm_new_token(token)

                    self.llm_loader.streamer.reset()
                    count -= 1
                except Exception:
                    if not testing:
                        print("nothing generated yet - retry in 0.5s")
                    time.sleep(0.5)

        t.join()
        return que.get()

    def apply_chat_template(self, user_message):
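        """Build a chat-template message list for ``user_message``.

        Gemma and Mistral chat templates do not accept a system role, so
        the system message is omitted for those models.
        """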
        result = (
            []
            if re.search(r"gemma|mistral", self.llm_loader.model_name, re.IGNORECASE)
            else [
                {
                    "role": "system",
                    "content": get_system_prompt_and_user_message()[0],
                }
            ]
        )
        result.append(
            {
                "role": "user",
                "content": user_message,
            }
        )
        return result
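

# ---------------------------------------------------------------------------
# Illustrative sketch (not used by the app): a minimal concrete subclass.
# The prompt/chain classes below are assumptions -- swap in whatever the
# rest of this repo actually uses.
#
# from langchain.chains import LLMChain
# from langchain.prompts import PromptTemplate
#
#
# class BasicChatInference(LLMInference):
#     def create_chain(self) -> Chain:
#         system_prompt, user_message = get_system_prompt_and_user_message()
#         prompt = PromptTemplate.from_template(system_prompt + "\n\n" + user_message)
#         return LLMChain(llm=self.llm_loader.llm, prompt=prompt)
#
#
# Usage (hypothetical loader setup elided):
#
# inference = BasicChatInference(llm_loader)
# result = inference.call_chain({"input": "Hello!"}, None)
# print(result["answer"])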