import os
from typing import List, Literal, Optional

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model name and generation limits; MODEL_NAME and MAX_INPUT_TOKEN_LENGTH can be
# overridden via environment variables.
MODEL_NAME = os.getenv("MODEL_NAME", "MBZUAI-Paris/Nile-Chat-12B")

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "2024"))

app = FastAPI(title="Nile-Chat-12B FastAPI")

# Loaded once at startup and reused by every request.
tokenizer = None
model = None


Role = Literal["system", "user", "assistant"]


class ChatMessage(BaseModel):
    role: Role
    content: str


class GenerateRequest(BaseModel):
    messages: List[ChatMessage] = Field(..., description="Conversation messages in OpenAI-like format")

    max_new_tokens: int = Field(DEFAULT_MAX_NEW_TOKENS, ge=1, le=MAX_MAX_NEW_TOKENS)
    do_sample: bool = True
    temperature: float = Field(0.6, ge=0.0, le=4.0)
    top_p: float = Field(0.9, ge=0.05, le=1.0)
    top_k: int = Field(50, ge=1, le=1000)
    repetition_penalty: float = Field(1.1, ge=1.0, le=2.0)

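# For reference, a request body shaped like GenerateRequest might look as
# follows (a sketch with illustrative values; only "messages" is required,
# everything else falls back to the defaults above):
#
#   {
#     "messages": [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Hello!"}
#     ],
#     "max_new_tokens": 256,
#     "temperature": 0.6,
#     "top_p": 0.9
#   }
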

class GenerateResponse(BaseModel):
    response: str
    trimmed: bool = False
    model: str = MODEL_NAME


@app.on_event("startup")
def startup_event():
    """Load the tokenizer and model once when the server starts."""
    global tokenizer, model

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # bfloat16 on GPU, float32 fallback on CPU.
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=dtype,
    )
    model.eval()

    print("Model ready")


@app.get("/health")
def health():
    return {"status": "ok", "model": MODEL_NAME}


@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    if not req.messages:
        raise HTTPException(status_code=400, detail="messages must not be empty")

    # Convert the Pydantic messages into the plain dicts expected by the
    # tokenizer's chat template.
    conversation = [m.model_dump() for m in req.messages]

    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    )

    # If the prompt is too long, keep only the most recent tokens.
    trimmed = False
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        trimmed = True

    input_ids = input_ids.to(model.device)

    # Lightweight request logging to stdout.
    last_user = next((m.content for m in reversed(req.messages) if m.role == "user"), "")
    print("\n=== Incoming Request ===")
    print("MODEL:", MODEL_NAME)
    print("LAST USER:", last_user)
    print("trimmed_input:", trimmed)
    print("input_tokens:", int(input_ids.shape[1]))

    # Generate without tracking gradients; sampling parameters come from the request.
    with torch.no_grad():
        out = model.generate(
            input_ids=input_ids,
            max_new_tokens=req.max_new_tokens,
            do_sample=req.do_sample,
            top_p=req.top_p,
            top_k=req.top_k,
            temperature=req.temperature,
            num_beams=1,
            repetition_penalty=req.repetition_penalty,
        )

    # Decode only the newly generated tokens, excluding the prompt.
    new_tokens = out[0, input_ids.shape[-1]:]
    response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    print("\n=== Model Response ===")
    print(response_text)
    print("======================\n")

    return GenerateResponse(response=response_text, trimmed=trimmed, model=MODEL_NAME)
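

# Example invocation (a sketch; assumes the server is running locally on the
# default uvicorn port 8000, which is not configured anywhere in this module):
#
#   curl -X POST http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello!"}], "max_new_tokens": 128}'


if __name__ == "__main__":
    # Optional local entry point (assumes uvicorn is installed); the service can
    # equally be started with `uvicorn <module>:app`.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)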