| import os |
| import json |
| import asyncio |
| import numpy as np |
| import onnxruntime as ort |
| import tiktoken |
| from fastapi import FastAPI, Request |
| from fastapi.responses import StreamingResponse |
| from fastapi.middleware.cors import CORSMiddleware |
|
|
| |
# FastAPI application object; all routes below are registered on it.
app = FastAPI()


# Allow cross-origin requests from any origin, with any method/header.
# NOTE(review): wildcard CORS is wide open — acceptable for a local demo,
# but lock `allow_origins` down before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
| |
# GPT-2 BPE tokenizer — presumably the vocab the model was trained with;
# TODO confirm against the training setup.
TOKENIZER = tiktoken.get_encoding("gpt2")
# Int8-quantized ONNX model file, expected in the current working directory.
MODEL_PATH = "SmaLLMPro_350M_int8.onnx"
# Width of the logit slice taken per step — looks like GPT-2's 50257-token
# vocab padded up (50304 = 786 * 64); verify against the exported model head.
VOCAB_SIZE = 50304
|
|
| |
| |
| |
# ONNX Runtime session configuration tuned for small-footprint CPU inference.
options = ort.SessionOptions()
options.intra_op_num_threads = 2   # threads used *within* a single operator
options.inter_op_num_threads = 2   # threads across independent operators
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL


# Load the model once at import time; all requests share this session.
print(f"🚀 Lade Modell {MODEL_PATH} mit CPU-Optimierung...")
session = ort.InferenceSession(
    MODEL_PATH,
    sess_options=options,
    providers=['CPUExecutionProvider']
)
|
|
def fast_top_k_sample(logits, k=25, temp=0.7, penalty=1.2, history=None):
    """Sample the next token id from raw logits using top-k sampling.

    Applies a CTRL-style repetition penalty to tokens present in
    ``history``, scales by temperature, then draws from the ``k``
    highest-probability candidates via a numerically stable softmax.

    Args:
        logits: 1-D float array of raw model logits.
        k: number of top candidates to sample from (clamped to vocab size).
        temp: softmax temperature; lower values sample more greedily.
        penalty: repetition penalty; > 1 discourages repeats, 1.0 disables.
        history: optional array of previously generated token ids.

    Returns:
        int: the sampled token id.
    """
    # Work on a copy so the caller's array is never mutated by the
    # in-place penalty step below.
    logits = np.asarray(logits, dtype=np.float32).copy()

    if history is not None and penalty != 1.0:
        unique_history = np.unique(history)
        # Guard against ids outside the logit range (e.g. padded vocab).
        seen = unique_history[unique_history < len(logits)]
        # CTRL-style penalty: plain division is wrong for negative logits —
        # dividing a negative logit by penalty > 1 would INCREASE its
        # probability. Divide positives, multiply negatives.
        seen_logits = logits[seen]
        logits[seen] = np.where(
            seen_logits > 0, seen_logits / penalty, seen_logits * penalty
        )

    # Temperature scaling (floor avoids division by zero).
    logits = logits / max(temp, 1e-6)

    # Clamp k so argpartition never receives an out-of-range kth index.
    k = min(k, logits.size)

    # Indices of the k largest logits, without a full sort.
    top_k_idx = np.argpartition(logits, -k)[-k:]
    top_k_logits = logits[top_k_idx]

    # Numerically stable softmax over the top-k candidates only.
    shifted_logits = top_k_logits - np.max(top_k_logits)
    exp_logits = np.exp(shifted_logits)
    probs = exp_logits / np.sum(exp_logits)

    return int(np.random.choice(top_k_idx, p=probs))
|
|
@app.post("/chat")
async def chat(request: Request):
    """Stream generated tokens to the client as Server-Sent Events.

    Expects a JSON body with optional keys: ``prompt``, ``maxLen``,
    ``temp``, ``topK``, ``penalty``. Returns a ``text/event-stream``
    response where each event carries one decoded token, or a JSON
    ``{"error": ...}`` payload if setup fails.
    """
    try:
        data = await request.json()
        user_prompt = data.get('prompt', '')
        max_len = int(data.get('maxLen', 100))
        temp = float(data.get('temp', 0.7))
        top_k = int(data.get('topK', 25))
        repetition_penalty = float(data.get('penalty', 1.2))

        # Instruction/Response template — presumably the format the model
        # was fine-tuned on; TODO confirm against the training setup.
        full_prompt = f"Instruction:\n{user_prompt}\n\nResponse:\n"
        tokens = TOKENIZER.encode(full_prompt)

        async def generate():
            nonlocal tokens
            # Running history of all token ids, fed to the repetition penalty.
            history = np.array(tokens, dtype=np.int32)

            for _ in range(max_len):
                # Keep only the last 1024 tokens and left-pad the fixed-size
                # input with zeros (assumes 0 is a safe pad id — TODO confirm).
                ctx = tokens[-1024:]
                input_array = np.zeros((1, 1024), dtype=np.int64)
                input_array[0, -len(ctx):] = ctx

                # session.run() is synchronous and CPU-heavy; run it in a
                # worker thread so the event loop (and other requests) stay
                # responsive during each decoding step.
                outputs = await asyncio.to_thread(
                    session.run, None, {'input': input_array}
                )

                # Logits at the last sequence position, clipped to the real
                # vocab width.
                logits = outputs[0][0, -1, :VOCAB_SIZE].astype(np.float32)

                next_token = fast_top_k_sample(
                    logits,
                    k=top_k,
                    temp=temp,
                    penalty=repetition_penalty,
                    history=history
                )

                # 50256 is GPT-2's <|endoftext|> token — stop generating.
                if next_token == 50256:
                    break

                tokens.append(next_token)
                history = np.append(history, next_token)

                # One SSE event per decoded token.
                yield f"data: {json.dumps({'token': TOKENIZER.decode([next_token])})}\n\n"

                # Brief pause so clients can render tokens incrementally.
                await asyncio.sleep(0.01)

        return StreamingResponse(generate(), media_type="text/event-stream")

    except Exception as e:
        # Top-level request boundary: log and report the failure as JSON.
        # NOTE(review): this still returns HTTP 200 — consider an HTTP 4xx/5xx.
        print(f"Error: {e}")
        return {"error": str(e)}
|
|
@app.get("/")
async def health():
    """Liveness probe: report engine status and configured thread count."""
    payload = {
        "status": "SmaLLMPro INT8 Engine Online",
        "threads": options.intra_op_num_threads,
    }
    return payload
|
|
if __name__ == "__main__":
    # Run a development server on all interfaces, port 7860
    # (the default port used by Hugging Face Spaces — presumably the
    # deployment target; confirm).
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)