import json
import traceback
from fastapi import FastAPI, HTTPException
from dotenv import load_dotenv
import os
import re
from huggingface_hub import ChatCompletionInputMessage, ChatCompletionInputTool
import litellm
# WARNING: disables TLS certificate verification for every HTTP call litellm
# makes, exposing all Groq traffic to man-in-the-middle attacks.
# NOTE(review): confirm this is only needed behind an intercepting proxy and
# remove (or scope it to that environment) if possible.
litellm.ssl_verify = False
from litellm.router import Router
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Optional, Literal, Type, Union
| |
|
load_dotenv()

app = FastAPI()

# Collect every Groq API key exposed through the environment as GROQ_<digits>
# (GROQ_0, GROQ_1, ...); each key becomes its own set of router deployments.
api_keys = [
    value
    for name, value in os.environ.items()
    if re.fullmatch(r"GROQ_\d+", name)
]
| |
|
# Per-model Groq rate limits used when wiring up the router deployments.
# NOTE(review): field names follow Groq's limit naming — presumably
# rpm = requests/minute, rpd = requests/day, tpm = tokens/minute,
# tpd = tokens/day, with None meaning no published cap — confirm against the
# Groq rate-limits documentation for the account tier in use.
models_data = {
    "allam-2-7b": {"rpm": 30, "rpd": 7000, "tpm": 6000},
    "compound-beta": {"rpm": 15, "rpd": 200, "tpm": 70000},
    "compound-beta-mini": {"rpm": 15, "rpd": 200, "tpm": 70000},
    "deepseek-r1-distill-llama-70b": {"rpm": 30, "rpd": 1000, "tpm": 6000},
    "gemma2-9b-it": {"rpm": 30, "rpd": 14400, "tpm": 15000, "tpd": 500000},
    "llama-3.1-8b-instant": {"rpm": 30, "rpd": 14400, "tpm": 6000, "tpd": 500000},
    "llama-3.3-70b-versatile": {"rpm": 30, "rpd": 1000, "tpm": 12000, "tpd": 100000},
    "llama3-70b-8192": {"rpm": 30, "rpd": 14400, "tpm": 6000, "tpd": 500000},
    "llama3-8b-8192": {"rpm": 30, "rpd": 14400, "tpm": 6000, "tpd": 500000},
    "meta-llama/llama-4-maverick-17b-128e-instruct": {"rpm": 30, "rpd": 1000, "tpm": 6000, "tpd": None},
    "meta-llama/llama-4-scout-17b-16e-instruct": {"rpm": 30, "rpd": 1000, "tpm": 30000, "tpd": None},
    "meta-llama/llama-guard-4-12b": {"rpm": 30, "rpd": 14400, "tpm": 15000, "tpd": 500000},
    "meta-llama/llama-prompt-guard-2-22m": {"rpm": 30, "rpd": 14400, "tpm": 15000, "tpd": None},
    "meta-llama/llama-prompt-guard-2-86m": {"rpm": 30, "rpd": 14400, "tpm": None, "tpd": None},
}
| |
|
# One router deployment per (model, API key) pair. The first key keeps the
# bare model name as its routing alias; every later key gets a "_<idx>"
# suffix so the router can address each key's quota independently.
model_list = [
    {
        "model_name": (name if idx == 0 else f"{name}_{idx}"),
        "litellm_params": {
            "model": f"groq/{name}",
            "api_key": key,
        },
        "timeout": 120,
        "max_retries": 5,
    }
    for name in models_data
    for idx, key in enumerate(api_keys)
]
| |
|
def generate_fallbacks_per_key():
    """Build the router's fallback table.

    For every model except the compound-beta family, each per-key alias maps
    to the same model under every *other* API key, so a rate-limited key
    falls back to the remaining keys.

    Returns:
        list[dict]: one ``{alias: [other aliases]}`` entry per (model, key).
    """
    def alias(name, idx):
        # Key 0 keeps the bare model name; other keys are suffixed "_<idx>".
        return name if idx == 0 else f"{name}_{idx}"

    skipped = {"compound-beta", "compound-beta-mini"}
    n_keys = len(api_keys)
    table = []
    for name in models_data:
        if name in skipped:
            continue
        for idx in range(n_keys):
            others = [alias(name, j) for j in range(n_keys) if j != idx]
            table.append({alias(name, idx): others})
    return table
| |
|
# Per-alias fallback table: same model, different API key (see
# generate_fallbacks_per_key above).
fallbacks = generate_fallbacks_per_key()


# litellm Router that spreads requests over all (model, key) deployments,
# retrying failed calls up to 5 times with a 10-second back-off.
router = Router(
    model_list=model_list,
    fallbacks=fallbacks,
    num_retries=5,
    retry_after=10
)
| |
|
# CORS: accept GET/POST from any origin with any headers.
# NOTE(review): allow_credentials=True combined with allow_origins=["*"] is
# forbidden by the CORS spec — browsers reject credentialed responses that
# carry a wildcard Access-Control-Allow-Origin. Either drop credentials or
# list explicit origins; confirm which the frontend actually needs.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["GET", "POST"],
    allow_origins=["*"]
)
| |
|
class ChatRequest(BaseModel):
    """Payload for POST /chat.

    Mirrors OpenAI-style chat-completion options; fields left as None are
    excluded when the request is forwarded to the router (see chat_with_groq).
    """

    models: List[str]  # model names to try in order; a single entry is validated strictly
    messages: List[ChatCompletionInputMessage]  # conversation history
    tools: Optional[List[ChatCompletionInputTool]] = None  # tool/function specs
    temperature: Optional[float] = None  # sampling temperature
    max_tokens: Optional[int] = None  # completion token cap
    n: Optional[int] = None  # number of completions requested
    stream: Optional[bool] = None  # streaming flag passed through to litellm
    stop: Optional[List[str]] = None  # stop sequences
| |
|
def clean_message(msg) -> dict:
    """Convert a message into a plain dict, dropping keys whose value is None.

    Supports pydantic models (``model_dump``), plain objects (``__dict__``),
    dicts, and anything ``dict()`` accepts (e.g. an iterable of key/value
    pairs).
    """
    if hasattr(msg, 'model_dump'):
        # Pydantic v2 model.
        raw = msg.model_dump()
    elif hasattr(msg, '__dict__'):
        # Plain object exposing instance attributes.
        raw = vars(msg)
    elif isinstance(msg, dict):
        raw = msg
    else:
        # Last resort: whatever dict() can consume.
        raw = dict(msg)
    # BUG FIX: the fallback branch previously skipped the None filter that the
    # other branches applied; all paths now share it, so forwarded payloads
    # never carry null fields.
    return {k: v for k, v in raw.items() if v is not None}
| |
|
| | @app.get("/") |
| | def main_page(): |
| | return {"status": "ok"} |
| |
|
| | @app.post("/chat") |
| | def chat_with_groq(req: ChatRequest): |
| | models = req.models |
| | if len(models) == 1 and (models[0] == "" or models[0] not in models_data.keys()): |
| | raise HTTPException(400, detail="Empty model field") |
| | messages = [clean_message(m) for m in req.messages] |
| | if len(models) == 1: |
| | try: |
| | resp = router.completion(model=models[0], messages=messages, **req.model_dump(exclude={"models", "messages"}, exclude_defaults=True, exclude_none=True)) |
| | print("Asked to", models[0], ":", messages) |
| | return {"error": False, "content": resp.choices[0].message.content} |
| | except Exception as e: |
| | traceback.print_exception(e) |
| | return {"error": True, "content": "Aucune clé ne fonctionne avec le modèle sélectionné, patientez ...."} |
| | else: |
| | for model in models: |
| | if model not in models_data.keys(): |
| | print(f"Erreur: {model} n'existe pas") |
| | continue |
| | try: |
| | resp = router.completion(model=model, messages=messages, **req.model_dump(exclude={"models", "messages"}, exclude_defaults=True, exclude_none=True)) |
| | print("Asked to", models[0], ":", messages) |
| | return {"error": False, "content": resp.choices[0].message.content} |
| | except Exception as e: |
| | traceback.print_exception(e) |
| | continue |
| | return {"error": True, "content": "Tous les modèles n'ont pas fonctionné avec les différentes clé, patientez ...."} |