| import os |
| import re |
| from typing import Union |
|
|
| import pytest |
| from _pytest.monkeypatch import MonkeyPatch |
| from requests import Response |
| from requests.sessions import Session |
| from xinference_client.client.restful.restful_client import ( |
| Client, |
| RESTfulChatModelHandle, |
| RESTfulEmbeddingModelHandle, |
| RESTfulGenerateModelHandle, |
| RESTfulRerankModelHandle, |
| ) |
| from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage |
|
|
|
|
class MockXinferenceClass:
    """Canned, network-free stand-ins for the xinference RESTful client.

    Each method here is monkeypatched onto the real class it mimics (see the
    ``setup_xinference_mock`` fixture), which is why ``self`` is annotated as
    the patched-into class rather than as ``MockXinferenceClass``. Annotations
    use PEP 484 forward-reference strings so defining this class does not
    require evaluating the client types.
    """

    def get_chat_model(
        self: "Client", model_uid: str
    ) -> "Union[RESTfulGenerateModelHandle, RESTfulChatModelHandle]":
        """Replacement for ``Client.get_model``.

        Returns the handle matching one of the four well-known mock uids
        ("generate", "chat", "embedding", "rerank"). Raises ``RuntimeError``
        for a malformed base URL or an unknown uid, mirroring a server 404.
        """
        if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url):
            raise RuntimeError("404 Not Found")

        if model_uid == "generate":
            return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
        if model_uid == "chat":
            return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
        if model_uid == "embedding":
            return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
        if model_uid == "rerank":
            return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
        raise RuntimeError("404 Not Found")

    def get(self: "Session", url: str, **kwargs) -> "Response":
        """Replacement for ``requests.Session.get`` serving canned JSON.

        Handles two endpoints: ``v1/models/<uid>`` (model metadata) and
        ``v1/cluster/auth``. Bug fix: the original fell off the end and
        implicitly returned ``None`` both for a valid "rerank" uid and for
        any unrecognized URL; every path now returns a ``Response``.
        """
        response = Response()

        if "v1/models/" in url:
            # The model uid is the last path segment of the URL.
            model_uid = url.split("/")[-1] or ""
            # Accept UUID-shaped uids or the four well-known mock names.
            if not re.match(
                r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", model_uid
            ) and model_uid not in {"generate", "chat", "embedding", "rerank"}:
                response.status_code = 404
                response._content = b"{}"
                return response

            # Reject malformed base URLs, mirroring a server-side 404.
            if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", url):
                response.status_code = 404
                response._content = b"{}"
                return response

            if model_uid in {"generate", "chat"}:
                response.status_code = 200
                response._content = b"""{
    "model_type": "LLM",
    "address": "127.0.0.1:43877",
    "accelerators": [
        "0",
        "1"
    ],
    "model_name": "chatglm3-6b",
    "model_lang": [
        "en"
    ],
    "model_ability": [
        "generate",
        "chat"
    ],
    "model_description": "latest chatglm3",
    "model_format": "pytorch",
    "model_size_in_billions": 7,
    "quantization": "none",
    "model_hub": "huggingface",
    "revision": null,
    "context_length": 2048,
    "replica": 1
}"""
                return response

            if model_uid == "embedding":
                response.status_code = 200
                response._content = b"""{
    "model_type": "embedding",
    "address": "127.0.0.1:43877",
    "accelerators": [
        "0",
        "1"
    ],
    "model_name": "bge",
    "model_lang": [
        "en"
    ],
    "revision": null,
    "max_tokens": 512
}"""
                return response

            if model_uid == "rerank":
                # Previously unhandled: this uid passed the checks above but
                # then fell through, so the call returned None.
                response.status_code = 200
                response._content = b"""{
    "model_type": "rerank",
    "address": "127.0.0.1:43877",
    "accelerators": [
        "0",
        "1"
    ],
    "model_name": "bge-reranker-base",
    "model_lang": [
        "en"
    ],
    "revision": null
}"""
                return response

        if "v1/cluster/auth" in url:
            response.status_code = 200
            response._content = b"""{
    "auth": true
}"""
            return response

        # Fallback for anything else (including UUID-shaped uids with no
        # canned payload): the original implicitly returned None here, which
        # breaks any caller expecting a Response.
        response.status_code = 404
        response._content = b"{}"
        return response

    def _check_cluster_authenticated(self):
        """Replacement for ``Client._check_cluster_authenticated``.

        Skips the real auth round-trip and marks the client as authenticated.
        """
        self._cluster_authed = True

    def rerank(
        self: "RESTfulRerankModelHandle",
        documents: list[str],
        query: str,
        top_n: int,
        return_documents: bool,
    ) -> dict:
        """Replacement for ``RESTfulRerankModelHandle.rerank``.

        Returns the first ``top_n`` documents, each with a fixed relevance
        score of 0.9. ``query`` and ``return_documents`` are accepted for
        signature compatibility but do not influence the canned result.
        Raises ``RuntimeError`` for an unknown uid or malformed base URL.
        """
        if (
            not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
            and self._model_uid != "rerank"
        ):
            raise RuntimeError("404 Not Found")

        if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._base_url):
            raise RuntimeError("404 Not Found")

        # A missing top_n falls back to a single result.
        if top_n is None:
            top_n = 1

        return {
            "results": [
                {"index": i, "document": doc, "relevance_score": 0.9}
                for i, doc in enumerate(documents[:top_n])
            ]
        }

    def create_embedding(
        self: "RESTfulEmbeddingModelHandle", input: Union[str, list[str]], **kwargs
    ) -> dict:
        """Replacement for ``RESTfulEmbeddingModelHandle.create_embedding``.

        Produces one constant 768-dimensional vector per input string.
        ``self`` is annotated as the embedding handle (the class this method
        is actually patched into). Raises ``RuntimeError`` for an unknown uid.
        """
        if (
            not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
            and self._model_uid != "embedding"
        ):
            raise RuntimeError("404 Not Found")

        # Normalize a single string to a one-element batch.
        if isinstance(input, str):
            input = [input]
        ipt_len = len(input)

        # Embedding/EmbeddingData/EmbeddingUsage appear to be TypedDicts
        # (consistent with the declared dict return type) — confirm against
        # xinference_client.types.
        embedding = Embedding(
            object="list",
            model=self._model_uid,
            data=[
                EmbeddingData(index=i, object="embedding", embedding=[1919.810 for _ in range(768)])
                for i in range(ipt_len)
            ],
            usage=EmbeddingUsage(prompt_tokens=ipt_len, total_tokens=ipt_len),
        )

        return embedding
|
|
|
|
# Mocking is opt-in: enabled only when the MOCK_SWITCH env var is "true".
MOCK = os.environ.get("MOCK_SWITCH", "false").lower() == "true"
|
|
|
|
@pytest.fixture
def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
    """Patch the xinference client and requests session with the mock
    implementations for the duration of a test, when MOCK is enabled."""
    if MOCK:
        # (target class, attribute name, mock replacement)
        patches = [
            (Client, "get_model", MockXinferenceClass.get_chat_model),
            (Client, "_check_cluster_authenticated", MockXinferenceClass._check_cluster_authenticated),
            (Session, "get", MockXinferenceClass.get),
            (RESTfulEmbeddingModelHandle, "create_embedding", MockXinferenceClass.create_embedding),
            (RESTfulRerankModelHandle, "rerank", MockXinferenceClass.rerank),
        ]
        for target, attr, replacement in patches:
            monkeypatch.setattr(target, attr, replacement)

    yield

    if MOCK:
        monkeypatch.undo()
|
|