Spaces:
Build error
Build error
Claude
refactor: Sprint 3b — remove dead code, non-functional provider, unused fields
11f019c unverified | """ | |
| Tests du provider Mistral AI (MistralProvider). | |
| Stratégie : | |
| - Pas d'appel réseau réel : SDK mocké via sys.modules. | |
| - is_configured() : vérifié via variables d'env ET import mock. | |
| - list_models() : mock de client.models.list() → comportement dynamique | |
| et fallback statique quand l'API échoue. | |
| - generate_content() : bifurcation vision / texte seul. | |
| """ | |
| # 1. stdlib | |
| import sys | |
| import types as _types | |
| # 2. third-party | |
| import pytest | |
| # 3. local | |
| from app.schemas.model_config import ProviderType | |
| from app.services.ai.provider_mistral import ( | |
| MistralProvider, | |
| _MISTRAL_FALLBACK_MODELS, | |
| _model_supports_vision, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Helpers — faux SDK Mistral | |
| # --------------------------------------------------------------------------- | |
| class _FakeCaps: | |
| """Capabilities d'un modèle Mistral (SDK v1.x).""" | |
| def __init__(self, vision: bool = False): | |
| self.vision = vision | |
| class _FakeModel: | |
| def __init__(self, id_: str, vision: bool = False, display_name: str | None = None): | |
| self.id = id_ | |
| self.display_name = display_name or id_ | |
| self.capabilities = _FakeCaps(vision=vision) | |
| class _FakeModelsListResponse: | |
| def __init__(self, models: list[_FakeModel]): | |
| self.data = models | |
| class _FakeModelsAPI: | |
| def __init__(self, models: list[_FakeModel]): | |
| self._models = models | |
| def list(self) -> _FakeModelsListResponse: | |
| return _FakeModelsListResponse(self._models) | |
| class _FakeMessage: | |
| content = "Voici le JSON de la page." | |
| class _FakeChoice: | |
| message = _FakeMessage() | |
| class _FakeChatResponse: | |
| choices = [_FakeChoice()] | |
| class _FakeChat: | |
| def complete(self, *, model, messages): | |
| return _FakeChatResponse() | |
| def _make_fake_mistralai(models: list[_FakeModel] | None = None) -> _types.ModuleType: | |
| """Crée un faux module mistralai avec Mistral class et modèles mockés.""" | |
| fake = _types.ModuleType("mistralai") | |
| chat = _FakeChat() | |
| models_api = _FakeModelsAPI(models or []) | |
| class _FakeMistral: | |
| def __init__(self, api_key): | |
| self.chat = chat | |
| self.models = models_api | |
| fake.Mistral = _FakeMistral | |
| return fake | |
| # --------------------------------------------------------------------------- | |
| # _model_supports_vision() — détection dynamique via l'API | |
| # --------------------------------------------------------------------------- | |
| def test_vision_detection_without_model_obj_returns_false(): | |
| """Sans objet modèle (pas de capabilities), retourne False par sécurité.""" | |
| assert _model_supports_vision("pixtral-large-latest") is False | |
| assert _model_supports_vision("mistral-small-latest") is False | |
| assert _model_supports_vision("codestral-latest") is False | |
| def test_vision_detection_uses_capabilities_from_api(): | |
| """La source de vérité est capabilities.vision retourné par l'API Mistral.""" | |
| m_vision = _FakeModel("some-model", vision=True) | |
| m_text = _FakeModel("some-model", vision=False) | |
| assert _model_supports_vision("some-model", m_vision) is True | |
| assert _model_supports_vision("some-model", m_text) is False | |
| def test_vision_detection_capabilities_false_on_any_model(): | |
| """capabilities.vision=False → pas de vision, quel que soit le nom.""" | |
| m = _FakeModel("pixtral-test", vision=False) | |
| assert _model_supports_vision("pixtral-test", m) is False | |
| def test_vision_detection_capabilities_true_on_any_model(): | |
| """capabilities.vision=True → vision activée, quel que soit le nom.""" | |
| m = _FakeModel("mistral-small-latest", vision=True) | |
| assert _model_supports_vision("mistral-small-latest", m) is True | |
| # --------------------------------------------------------------------------- | |
| # is_configured() | |
| # --------------------------------------------------------------------------- | |
| def test_is_configured_true(monkeypatch): | |
| """Clé présente + mistralai v1.x importable → True.""" | |
| monkeypatch.setenv("MISTRAL_API_KEY", "test-key-abc") | |
| fake = _make_fake_mistralai() | |
| monkeypatch.setitem(sys.modules, "mistralai", fake) | |
| assert MistralProvider().is_configured() is True | |
| def test_is_configured_false_no_key(monkeypatch): | |
| """Pas de clé → False, même si mistralai est installé.""" | |
| monkeypatch.delenv("MISTRAL_API_KEY", raising=False) | |
| fake = _make_fake_mistralai() | |
| monkeypatch.setitem(sys.modules, "mistralai", fake) | |
| assert MistralProvider().is_configured() is False | |
| def test_is_configured_false_empty_key(monkeypatch): | |
| monkeypatch.setenv("MISTRAL_API_KEY", "") | |
| fake = _make_fake_mistralai() | |
| monkeypatch.setitem(sys.modules, "mistralai", fake) | |
| assert MistralProvider().is_configured() is False | |
| def test_is_configured_false_v0x_installed(monkeypatch): | |
| """Clé présente mais mistralai v0.x (pas de classe Mistral) → False.""" | |
| monkeypatch.setenv("MISTRAL_API_KEY", "test-key") | |
| fake_v0 = _types.ModuleType("mistralai") | |
| # Pas d'attribut Mistral → from mistralai import Mistral lèvera ImportError | |
| monkeypatch.setitem(sys.modules, "mistralai", fake_v0) | |
| assert MistralProvider().is_configured() is False | |
| def test_is_configured_false_mistralai_not_installed(monkeypatch): | |
| """Clé présente mais mistralai pas du tout installé → False.""" | |
| monkeypatch.setenv("MISTRAL_API_KEY", "test-key") | |
| # Supprimer mistralai du chemin d'import | |
| monkeypatch.setitem(sys.modules, "mistralai", None) # type: ignore[arg-type] | |
| assert MistralProvider().is_configured() is False | |
| # --------------------------------------------------------------------------- | |
| # provider_type | |
| # --------------------------------------------------------------------------- | |
| def test_provider_type(monkeypatch): | |
| fake = _make_fake_mistralai() | |
| monkeypatch.setitem(sys.modules, "mistralai", fake) | |
| assert MistralProvider().provider_type == ProviderType.MISTRAL | |
| # --------------------------------------------------------------------------- | |
| # list_models() — comportement dynamique | |
| # --------------------------------------------------------------------------- | |
| def _setup_list_models(monkeypatch, models: list[_FakeModel]) -> None: | |
| """Configure le monkeypatch pour list_models().""" | |
| monkeypatch.setenv("MISTRAL_API_KEY", "test-key") | |
| fake = _make_fake_mistralai(models) | |
| monkeypatch.setitem(sys.modules, "mistralai", fake) | |
| def test_list_models_dynamic_returns_all_non_embed(monkeypatch): | |
| """list_models() retourne tous les modèles sauf embeddings/modération. | |
| mistral-ocr-latest est toujours ajouté s'il n'est pas dans la liste dynamique.""" | |
| _setup_list_models(monkeypatch, [ | |
| _FakeModel("pixtral-large-latest", vision=True), | |
| _FakeModel("pixtral-12b-2409", vision=True), | |
| _FakeModel("mistral-large-latest", vision=False), | |
| _FakeModel("mistral-embed", vision=False), # exclut | |
| _FakeModel("mistral-moderation", vision=False), # exclut | |
| ]) | |
| models = MistralProvider().list_models() | |
| ids = {m.model_id for m in models} | |
| assert "pixtral-large-latest" in ids | |
| assert "pixtral-12b-2409" in ids | |
| assert "mistral-large-latest" in ids | |
| assert "mistral-ocr-latest" in ids # ajouté automatiquement | |
| assert "mistral-embed" not in ids | |
| assert "mistral-moderation" not in ids | |
| assert len(models) == 4 # 3 filtres + OCR ajouté | |
| def test_list_models_vision_flag_from_capabilities(monkeypatch): | |
| """supports_vision reflète capabilities.vision du SDK.""" | |
| _setup_list_models(monkeypatch, [ | |
| _FakeModel("pixtral-large-latest", vision=True), | |
| _FakeModel("mistral-large-latest", vision=False), | |
| ]) | |
| models = MistralProvider().list_models() | |
| by_id = {m.model_id: m for m in models} | |
| assert by_id["pixtral-large-latest"].supports_vision is True | |
| assert by_id["mistral-large-latest"].supports_vision is False | |
| def test_list_models_all_mistral_provider(monkeypatch): | |
| _setup_list_models(monkeypatch, [ | |
| _FakeModel("pixtral-large-latest", vision=True), | |
| _FakeModel("mistral-large-latest", vision=False), | |
| ]) | |
| models = MistralProvider().list_models() | |
| assert all(m.provider == ProviderType.MISTRAL for m in models) | |
| def test_list_models_fallback_when_api_fails(monkeypatch): | |
| """Si client.models.list() lève une exception, retourne la liste statique.""" | |
| monkeypatch.setenv("MISTRAL_API_KEY", "test-key") | |
| fake = _types.ModuleType("mistralai") | |
| class _FailingModels: | |
| def list(self): | |
| raise RuntimeError("API timeout") | |
| class _FakeMistral: | |
| def __init__(self, api_key): | |
| self.models = _FailingModels() | |
| fake.Mistral = _FakeMistral | |
| monkeypatch.setitem(sys.modules, "mistralai", fake) | |
| models = MistralProvider().list_models() | |
| # Fallback = _MISTRAL_FALLBACK_MODELS = Pixtral Large + 12B + mistral-ocr-latest | |
| assert len(models) == 3 | |
| ids = {m.model_id for m in models} | |
| assert "pixtral-large-latest" in ids | |
| assert "pixtral-12b-2409" in ids | |
| assert "mistral-ocr-latest" in ids | |
| def test_list_models_raises_if_not_configured(monkeypatch): | |
| monkeypatch.delenv("MISTRAL_API_KEY", raising=False) | |
| with pytest.raises(RuntimeError, match="MISTRAL_API_KEY"): | |
| MistralProvider().list_models() | |
| # --------------------------------------------------------------------------- | |
| # generate_content() — bifurcation vision / texte | |
| # --------------------------------------------------------------------------- | |
| def test_generate_content_vision_model_returns_text(monkeypatch): | |
| """Modèle vision : envoie l'image et retourne la réponse.""" | |
| monkeypatch.setenv("MISTRAL_API_KEY", "test-key") | |
| fake = _make_fake_mistralai() | |
| monkeypatch.setitem(sys.modules, "mistralai", fake) | |
| result = MistralProvider().generate_content( | |
| b"fake-jpeg", "Analyse ce folio.", "pixtral-large-latest", | |
| supports_vision=True, | |
| ) | |
| assert result == "Voici le JSON de la page." | |
| def test_generate_content_text_model_returns_text(monkeypatch): | |
| """Modèle texte (supports_vision=False) : envoie seulement le prompt.""" | |
| monkeypatch.setenv("MISTRAL_API_KEY", "test-key") | |
| fake = _make_fake_mistralai() | |
| monkeypatch.setitem(sys.modules, "mistralai", fake) | |
| result = MistralProvider().generate_content( | |
| b"fake-jpeg", "Analyse ce folio.", "mistral-large-latest", | |
| supports_vision=False, | |
| ) | |
| assert result == "Voici le JSON de la page." | |
| def test_generate_content_vision_sends_image_url(monkeypatch): | |
| """Modèle vision : le message content contient image_url + text.""" | |
| monkeypatch.setenv("MISTRAL_API_KEY", "test-key") | |
| captured: list[dict] = [] | |
| class _CapturingChat: | |
| def complete(self, *, model, messages): | |
| captured.extend(messages) | |
| return _FakeChatResponse() | |
| class _FakeMistral: | |
| def __init__(self, api_key): | |
| self.chat = _CapturingChat() | |
| self.models = _FakeModelsAPI([]) | |
| fake = _types.ModuleType("mistralai") | |
| fake.Mistral = _FakeMistral | |
| monkeypatch.setitem(sys.modules, "mistralai", fake) | |
| MistralProvider().generate_content(b"jpeg", "prompt", "pixtral-large-latest", supports_vision=True) | |
| assert len(captured) == 1 | |
| content = captured[0]["content"] | |
| assert isinstance(content, list) | |
| types_sent = {item["type"] for item in content} | |
| assert "image_url" in types_sent | |
| assert "text" in types_sent | |
| def test_generate_content_text_sends_string_content(monkeypatch): | |
| """Modèle texte (supports_vision=False) : le message content est une chaîne (pas d'image).""" | |
| monkeypatch.setenv("MISTRAL_API_KEY", "test-key") | |
| captured: list[dict] = [] | |
| class _CapturingChat: | |
| def complete(self, *, model, messages): | |
| captured.extend(messages) | |
| return _FakeChatResponse() | |
| class _FakeMistral: | |
| def __init__(self, api_key): | |
| self.chat = _CapturingChat() | |
| self.models = _FakeModelsAPI([]) | |
| fake = _types.ModuleType("mistralai") | |
| fake.Mistral = _FakeMistral | |
| monkeypatch.setitem(sys.modules, "mistralai", fake) | |
| MistralProvider().generate_content(b"jpeg", "mon prompt", "mistral-large-latest", supports_vision=False) | |
| assert len(captured) == 1 | |
| assert captured[0]["content"] == "mon prompt" | |
| def test_generate_content_raises_if_not_configured(monkeypatch): | |
| monkeypatch.delenv("MISTRAL_API_KEY", raising=False) | |
| with pytest.raises(RuntimeError, match="MISTRAL_API_KEY"): | |
| MistralProvider().generate_content(b"img", "prompt", "pixtral-large-latest") | |
| def test_generate_content_raises_if_v0x_installed(monkeypatch): | |
| """Si mistralai v0.x est installé (is_configured() → False), RuntimeError clair.""" | |
| monkeypatch.setenv("MISTRAL_API_KEY", "test-key") | |
| fake_v0 = _types.ModuleType("mistralai") | |
| monkeypatch.setitem(sys.modules, "mistralai", fake_v0) | |
| with pytest.raises(RuntimeError, match="mistralai>=1.0"): | |
| MistralProvider().generate_content(b"img", "prompt", "pixtral-large-latest") | |
| def test_generate_content_empty_response(monkeypatch): | |
| """Si choices est vide, retourne une chaîne vide sans exception.""" | |
| monkeypatch.setenv("MISTRAL_API_KEY", "test-key") | |
| class _EmptyChat: | |
| def complete(self, *, model, messages): | |
| class _EmptyResp: | |
| choices = [] | |
| return _EmptyResp() | |
| class _FakeMistral: | |
| def __init__(self, api_key): | |
| self.chat = _EmptyChat() | |
| self.models = _FakeModelsAPI([]) | |
| fake = _types.ModuleType("mistralai") | |
| fake.Mistral = _FakeMistral | |
| monkeypatch.setitem(sys.modules, "mistralai", fake) | |
| result = MistralProvider().generate_content(b"img", "prompt", "pixtral-large-latest") | |
| assert result == "" | |
| # --------------------------------------------------------------------------- | |
| # generate_content() — chemin OCR dédié (mistral-ocr-latest) | |
| # --------------------------------------------------------------------------- | |
| def test_generate_content_ocr_uses_ocr_endpoint(monkeypatch): | |
| """mistral-ocr-latest utilise client.ocr.process(), pas client.chat.complete().""" | |
| monkeypatch.setenv("MISTRAL_API_KEY", "test-key") | |
| ocr_calls: list[dict] = [] | |
| chat_calls: list = [] | |
| class _FakeOCRPage: | |
| markdown = "Explicit liber primus..." | |
| class _FakeOCRResponse: | |
| pages = [_FakeOCRPage(), _FakeOCRPage()] | |
| class _FakeOCR: | |
| def process(self, *, model, document): | |
| ocr_calls.append({"model": model, "document": document}) | |
| return _FakeOCRResponse() | |
| class _FakeChat: | |
| def complete(self, *, model, messages): | |
| chat_calls.append(messages) | |
| class _FakeMistral: | |
| def __init__(self, api_key): | |
| self.ocr = _FakeOCR() | |
| self.chat = _FakeChat() | |
| self.models = _FakeModelsAPI([]) | |
| fake = _types.ModuleType("mistralai") | |
| fake.Mistral = _FakeMistral | |
| monkeypatch.setitem(sys.modules, "mistralai", fake) | |
| result = MistralProvider().generate_content(b"jpeg", "prompt", "mistral-ocr-latest") | |
| # OCR endpoint appelé, pas chat | |
| assert len(ocr_calls) == 1 | |
| assert len(chat_calls) == 0 | |
| assert ocr_calls[0]["model"] == "mistral-ocr-latest" | |
| # Document doit être image_url avec data URI | |
| doc = ocr_calls[0]["document"] | |
| assert doc["type"] == "image_url" | |
| assert doc["image_url"]["url"].startswith("data:image/jpeg;base64,") | |
| # Résultat = pages concaténées | |
| assert "Explicit liber primus..." in result | |
| def test_generate_content_ocr_concatenates_pages(monkeypatch): | |
| """OCR multi-pages : les markdowns sont concaténés par double saut de ligne.""" | |
| monkeypatch.setenv("MISTRAL_API_KEY", "test-key") | |
| class _Page: | |
| def __init__(self, md): | |
| self.markdown = md | |
| class _FakeOCRResponse: | |
| pages = [_Page("Page 1 texte"), _Page("Page 2 texte")] | |
| class _FakeOCR: | |
| def process(self, **kwargs): | |
| return _FakeOCRResponse() | |
| class _FakeMistral: | |
| def __init__(self, api_key): | |
| self.ocr = _FakeOCR() | |
| self.models = _FakeModelsAPI([]) | |
| fake = _types.ModuleType("mistralai") | |
| fake.Mistral = _FakeMistral | |
| monkeypatch.setitem(sys.modules, "mistralai", fake) | |
| result = MistralProvider().generate_content(b"jpeg", "prompt", "mistral-ocr-latest") | |
| assert "Page 1 texte" in result | |
| assert "Page 2 texte" in result | |
| assert "\n\n" in result | |
| def test_generate_content_ocr_model_not_called_for_vision(monkeypatch): | |
| """Un modèle Pixtral NE passe PAS par l'endpoint OCR.""" | |
| monkeypatch.setenv("MISTRAL_API_KEY", "test-key") | |
| ocr_called = [] | |
| class _FakeOCR: | |
| def process(self, **kwargs): | |
| ocr_called.append(True) | |
| class _FakeMistral: | |
| def __init__(self, api_key): | |
| self.ocr = _FakeOCR() | |
| self.chat = type("C", (), {"complete": lambda self, **k: _FakeChatResponse()})() | |
| self.models = _FakeModelsAPI([]) | |
| fake = _types.ModuleType("mistralai") | |
| fake.Mistral = _FakeMistral | |
| monkeypatch.setitem(sys.modules, "mistralai", fake) | |
| MistralProvider().generate_content(b"jpeg", "prompt", "pixtral-large-latest", supports_vision=True) | |
| assert len(ocr_called) == 0 | |
| def test_generate_content_ocr_model_detected_by_id(monkeypatch): | |
| """Tout modèle contenant 'ocr' dans l'ID utilise l'endpoint OCR.""" | |
| monkeypatch.setenv("MISTRAL_API_KEY", "test-key") | |
| ocr_called = [] | |
| class _FakeOCR: | |
| def process(self, **kwargs): | |
| ocr_called.append(True) | |
| class R: | |
| pages = [] | |
| return R() | |
| class _FakeMistral: | |
| def __init__(self, api_key): | |
| self.ocr = _FakeOCR() | |
| self.models = _FakeModelsAPI([]) | |
| fake = _types.ModuleType("mistralai") | |
| fake.Mistral = _FakeMistral | |
| monkeypatch.setitem(sys.modules, "mistralai", fake) | |
| MistralProvider().generate_content(b"jpeg", "prompt", "mistral-ocr-latest") | |
| assert len(ocr_called) == 1 | |