import asyncio
import json
import os
import time
import uuid
from datetime import datetime
from typing import List, Literal, Optional, Union

import uvicorn
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from llama_cpp import Llama
from pydantic import BaseModel, Field
|
|
# --- Configuration ---

# Accepted bearer tokens. Keeping secrets in source is unsafe, so the set can
# be supplied through the API_KEYS environment variable (comma-separated).
# The built-in demo keys below are used only when API_KEYS is unset. An
# API_KEYS value with no non-blank entries accepts no key at all (fail closed).
_DEFAULT_API_KEYS = "sk-adminkey02,sk-testkey123,sk-userkey456,sk-demokey789"
VALID_API_KEYS = {
    key.strip()
    for key in os.environ.get("API_KEYS", _DEFAULT_API_KEYS).split(",")
    if key.strip()
}

# Model file and the model name reported to clients. Both can be overridden
# from the environment and default to the previously hard-coded values.
MODEL_PATH = os.environ.get(
    "MODEL_PATH", "capybarahermes-2.5-mistral-7b.Q5_K_M.gguf"
)
MODEL_NAME = os.environ.get("MODEL_NAME", "capybarahermes-2.5-mistral-7b")
|
|
# --- Global state ---

# The loaded llama_cpp model; None until the startup hook assigns it.
llm: Optional[Llama] = None

# Extracts the bearer token from the Authorization header; consumed by
# verify_api_key through Depends().
security = HTTPBearer()
|
|
# --- OpenAI-compatible request/response schemas ---


def _new_completion_id() -> str:
    """Return a fresh ``chatcmpl-`` identifier derived from a random UUID4."""
    return "chatcmpl-" + uuid.uuid4().hex


def _unix_now() -> int:
    """Return the current Unix time truncated to whole seconds."""
    return int(time.time())


class Message(BaseModel):
    """A single chat turn: the speaker's role and the text of the turn."""

    role: Literal["system", "user", "assistant"]
    content: str


class ChatCompletionRequest(BaseModel):
    """Request body accepted by the chat-completions endpoint."""

    model: str = Field(default=MODEL_NAME)
    messages: List[Message]
    max_tokens: Optional[int] = Field(default=512)
    temperature: Optional[float] = Field(default=0.7)
    top_p: Optional[float] = Field(default=0.9)
    n: Optional[int] = Field(default=1)  # accepted for API compatibility
    stream: Optional[bool] = Field(default=False)
    stop: Optional[Union[str, List[str]]] = Field(default=None)


class ChatCompletionChoice(BaseModel):
    """One generated alternative inside a ChatCompletionResponse."""

    index: int
    message: Message
    finish_reason: Optional[Literal["stop", "length"]] = Field(default=None)


class Usage(BaseModel):
    """Token accounting reported alongside a completion."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    """Non-streaming chat completion result in OpenAI's response shape."""

    id: str = Field(default_factory=_new_completion_id)
    object: str = Field(default="chat.completion")
    created: int = Field(default_factory=_unix_now)
    model: str = Field(default=MODEL_NAME)
    choices: List[ChatCompletionChoice]
    usage: Usage


class ModelData(BaseModel):
    """An entry in the model listing."""

    id: str
    object: str = Field(default="model")
    created: int = Field(default_factory=_unix_now)
    owned_by: str = Field(default="user")


class ModelsResponse(BaseModel):
    """Envelope returned by the model-listing endpoint."""

    object: str = Field(default="list")
    data: List[ModelData]
|
|
# --- FastAPI application ---

app = FastAPI(
    title="CapybaraHermes OpenAI-Compatible API",
    description=f"An OpenAI-compatible API for the {MODEL_NAME} model.",
    version="1.0.0",
    docs_url="/v1/docs",
    redoc_url="/v1/redoc",
)

# CORS origins come from CORS_ALLOW_ORIGINS (comma-separated; default "*").
# FastAPI's docs state that the wildcard origin cannot be combined with
# allow_credentials=True. Starlette's CORSMiddleware then reflects the
# caller's Origin instead, which amounts to granting credentialed access to
# every site. This API authenticates with an explicit Bearer header rather
# than cookies, so credentials are enabled only when concrete origins are
# listed.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
# --- Authentication ---


def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return the caller's bearer token if it is one of VALID_API_KEYS.

    Used as a FastAPI dependency on the protected /v1 routes. With the
    default HTTPBearer settings, a request with no Authorization header is
    presumably rejected by ``security`` before this function runs.

    Raises:
        HTTPException: 401 when the token is not an accepted key. The
            response carries ``WWW-Authenticate: Bearer``, as RFC 6750
            requires for 401 responses to bearer-token requests.
    """
    token = credentials.credentials
    if token not in VALID_API_KEYS:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing API key",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return token
|
|
# --- Model loading ---


@app.on_event("startup")
def load_model():
    """Load the GGUF model at MODEL_PATH into the module-level ``llm``.

    Raises:
        FileNotFoundError: if MODEL_PATH does not exist. The error propagates
            out of the startup hook.
    """
    global llm
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model file not found at {MODEL_PATH}")

    # The status messages are plain ASCII. The previous emoji prefixes had
    # been mangled by an encoding round-trip, which also split the second
    # literal across lines. Emoji can also fail to print when stdout uses a
    # legacy code page.
    print("Loading GGUF model...")
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=4096,  # context window in tokens (prompt + completion)
        n_threads=2,
        n_batch=512,
        verbose=False,
        use_mlock=True,  # keep the model resident in RAM
        n_gpu_layers=0,  # no layers offloaded to a GPU
    )
    print("Model loaded successfully!")
|
|
# --- Helpers ---


def format_messages(messages: List[Message]) -> str:
    """Render chat messages as a ChatML prompt ending with an open assistant turn.

    Each message becomes ``<|im_start|>{role}`` + newline + ``{content}<|im_end|>``
    + newline. A trailing ``<|im_start|>assistant`` line cues the model to reply.

    NOTE(review): content is inserted verbatim. Text that itself contains
    ``<|im_start|>`` / ``<|im_end|>`` could open fake turns if the tokenizer
    treats those markers as special tokens. Confirm llama_cpp's tokenization
    behaviour before exposing this to untrusted users.
    """
    # Collect the pieces and join them once, instead of repeated ``+=``
    # (quadratic in the general case).
    turns = [
        f"<|im_start|>{message.role}\n{message.content}<|im_end|>\n"
        for message in messages
    ]
    turns.append("<|im_start|>assistant\n")
    return "".join(turns)
|
|
def count_tokens_rough(text: str) -> int:
    """Estimate a token count as the number of whitespace-separated words.

    This is only an approximation. Subword tokenizers typically produce a
    different (usually larger) number of tokens.
    """
    words = text.split()
    return len(words)
|
|
# --- Routes ---


@app.get("/v1/health")
async def health_check():
    """Report liveness and whether the model has been loaded (no auth required)."""
    model_ready = llm is not None
    return {"status": "healthy", "model_loaded": model_ready}
|
|
@app.get("/v1/models", response_model=ModelsResponse)
async def list_models(api_key: str = Depends(verify_api_key)):
    """Return the OpenAI-style model listing; MODEL_NAME is the only entry."""
    served = ModelData(id=MODEL_NAME)
    return ModelsResponse(data=[served])
|
|
# --- Chat completions ---

# llama_cpp.Llama carries per-context generation state, so calls against the
# shared ``llm`` must never overlap. Generation runs in worker threads (so it
# no longer blocks the event loop), and this lock serializes it. Created
# lazily by _get_llm_lock().
_llm_lock = None

# ChatML turn delimiters, always used as stop sequences so the model does not
# run on into an invented next turn.
_CHATML_STOPS = ["<|im_end|>", "<|im_start|>"]

# finish_reason values accepted by ChatCompletionChoice.
_FINISH_REASONS = ("stop", "length")


def _get_llm_lock():
    """Return the asyncio.Lock that guards ``llm``, creating it on first use.

    Must be called from code running on the event loop. Deferring creation
    matters on Python < 3.10, where an asyncio.Lock binds to the loop that is
    current at construction. At import time that may not be the loop uvicorn
    serves on.
    """
    global _llm_lock
    if _llm_lock is None:
        _llm_lock = asyncio.Lock()
    return _llm_lock


def _stop_sequences(stop: Optional[Union[str, List[str]]]) -> List[str]:
    """Return the ChatML delimiters followed by the caller's ``stop`` value.

    ``stop`` may be None, a single string, or a list of strings. A bare string
    is wrapped in a list; concatenating it onto a list raises TypeError.
    Falsy values (None, "", []) add nothing.
    """
    if not stop:
        return list(_CHATML_STOPS)
    if isinstance(stop, str):
        return _CHATML_STOPS + [stop]
    return _CHATML_STOPS + list(stop)


def _sampling_kwargs(request: ChatCompletionRequest) -> dict:
    """Build the keyword arguments shared by both ``llm(...)`` call sites.

    ``temperature`` and ``top_p`` are Optional in the request model. When a
    client sends an explicit null, they are left out so llama_cpp applies its
    own defaults rather than receiving None. ``max_tokens`` is forwarded
    unchanged, including None.
    """
    kwargs = {
        "max_tokens": request.max_tokens,
        "stop": _stop_sequences(request.stop),
        "echo": False,
    }
    if request.temperature is not None:
        kwargs["temperature"] = request.temperature
    if request.top_p is not None:
        kwargs["top_p"] = request.top_p
    return kwargs


def _finish_reason(raw: Optional[str]) -> str:
    """Map llama_cpp's finish_reason onto one ChatCompletionChoice accepts."""
    return raw if raw in _FINISH_REASONS else "stop"


def _sse(payload) -> str:
    """Frame one Server-Sent Events ``data:`` message (dicts are JSON-encoded)."""
    body = payload if isinstance(payload, str) else json.dumps(payload)
    return f"data: {body}\n\n"


async def _stream_chat_chunks(prompt: str, llm_kwargs: dict):
    """Yield OpenAI ``chat.completion.chunk`` SSE events for one generation.

    Holds the model lock for the whole stream. Each step of llama_cpp's
    blocking generator runs in a worker thread, so other requests (e.g.
    /v1/health) keep being served while tokens are produced.
    """
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    created_time = int(time.time())
    finish_reason = None

    def chunk(delta: dict, reason: Optional[str]) -> str:
        return _sse({
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created_time,
            "model": MODEL_NAME,
            "choices": [{"index": 0, "delta": delta, "finish_reason": reason}],
        })

    async with _get_llm_lock():
        outputs = await run_in_threadpool(llm, prompt, stream=True, **llm_kwargs)
        while True:
            # The None default keeps StopIteration from being raised across
            # the thread/await boundary. Stream items are dicts, never None.
            output = await run_in_threadpool(next, outputs, None)
            if output is None:
                break
            choices = output.get("choices") or []
            if not choices:
                continue
            choice = choices[0]
            # Remember the most recent non-empty finish_reason reported.
            finish_reason = choice.get("finish_reason") or finish_reason
            yield chunk({"content": choice.get("text", "")}, None)

    # Report why generation actually ended ("length" when max_tokens was
    # reached) instead of always claiming "stop".
    yield chunk({}, _finish_reason(finish_reason))
    yield _sse("[DONE]")


@app.post("/v1/chat/completions")
async def create_chat_completion(
    request: ChatCompletionRequest,
    api_key: str = Depends(verify_api_key),
):
    """Create a chat completion for ``request.messages`` (OpenAI-compatible).

    When ``request.stream`` is true, returns a text/event-stream of
    ``chat.completion.chunk`` events. Otherwise returns a
    ChatCompletionResponse. Exactly one choice is produced; ``request.n``
    is not used.

    Raises:
        HTTPException: 503 while the model has not been loaded yet.
    """
    if llm is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet")

    prompt = format_messages(request.messages)
    llm_kwargs = _sampling_kwargs(request)

    if request.stream:
        return StreamingResponse(
            _stream_chat_chunks(prompt, llm_kwargs),
            media_type="text/event-stream",
        )

    # Run the blocking llama.cpp call in a worker thread, one at a time.
    async with _get_llm_lock():
        response = await run_in_threadpool(llm, prompt, **llm_kwargs)

    choice = response["choices"][0]
    response_text = choice["text"].strip()

    # Prefer llama_cpp's own token counts. Fall back to the word-count
    # estimate only when the response carries no usage figures.
    usage = response.get("usage") or {}
    prompt_tokens = usage.get("prompt_tokens")
    completion_tokens = usage.get("completion_tokens")
    if prompt_tokens is None or completion_tokens is None:
        prompt_tokens = count_tokens_rough(prompt)
        completion_tokens = count_tokens_rough(response_text)

    return ChatCompletionResponse(
        model=MODEL_NAME,
        choices=[
            ChatCompletionChoice(
                index=0,
                message=Message(role="assistant", content=response_text),
                finish_reason=_finish_reason(choice.get("finish_reason")),
            )
        ],
        usage=Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
|
|
if __name__ == "__main__":
    # HOST and PORT may be overridden from the environment; the defaults are
    # the previously hard-coded bind address and port.
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
    )
|
|