| """
|
| API Server for Mamba Swarm
|
| FastAPI-based server for serving the distributed Mamba language model
|
| """
|
|
|
# Standard library
import asyncio
import json
import logging
import platform
import time
from contextlib import asynccontextmanager
from typing import List, Optional, Dict, Any, AsyncGenerator

# Third-party
import torch
import uvicorn
from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field

# Local
from system.mambaSwarm import SwarmEngine
from system.inference import InferenceEngine
from routing.router import Router
from training.trainer import setup_logging
|
|
|
|
|
class GenerationRequest(BaseModel):
    """Request body for /generate, /generate/stream and /generate/batch."""

    prompt: str = Field(..., description="Input text prompt")
    # Sampling controls; value bounds are enforced by pydantic (ge/le).
    max_length: int = Field(default=100, ge=1, le=2048, description="Maximum generation length")
    temperature: float = Field(default=0.7, ge=0.1, le=2.0, description="Sampling temperature")
    top_p: float = Field(default=0.9, ge=0.1, le=1.0, description="Top-p sampling")
    top_k: int = Field(default=50, ge=1, le=100, description="Top-k sampling")
    repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0, description="Repetition penalty")
    # Must be True when calling /generate/stream — that endpoint rejects
    # requests with stream=False.
    stream: bool = Field(default=False, description="Enable streaming response")
    domain: Optional[str] = Field(default=None, description="Specific domain for routing")
|
|
|
class GenerationResponse(BaseModel):
    """Response payload for a completed (non-streaming) generation."""

    generated_text: str
    # Echo of the request prompt, for client-side correlation.
    prompt: str
    # Wall-clock seconds spent generating (measured server-side).
    generation_time: float
    tokens_generated: int
    # Opaque metadata passed through from the inference engine.
    model_info: Dict[str, Any]
|
|
|
class StreamingToken(BaseModel):
    """One token frame of a streaming response (serialized as SSE data)."""

    token: str
    # True on the last frame of the stream; also set on error frames.
    is_final: bool = False
    # Extra per-token info; carries {"error": ...} when streaming fails.
    metadata: Optional[Dict[str, Any]] = None
|
|
|
class HealthResponse(BaseModel):
    """Payload returned by GET /health."""

    status: str
    # Whatever SwarmEngine.get_status() reports, passed through unmodified.
    swarm_status: Dict[str, Any]
    system_info: Dict[str, Any]
    # Server-side epoch seconds when the check ran.
    timestamp: float
|
|
|
class ModelInfo(BaseModel):
    """Payload returned by GET /model/info."""

    total_parameters: int
    active_encoders: int
    total_encoders: int
    # Keys: system_memory_gb, gpu_memory_gb, cache_size_gb (see endpoint).
    memory_usage: Dict[str, float]
    # Device identifiers, e.g. ["cuda:0"] or ["cpu"].
    device_info: List[str]
|
|
|
|
|
# Module-level singletons populated by the lifespan handler at startup;
# None until initialization completes (dependencies below return 503 then).
swarm_engine: Optional[SwarmEngine] = None
inference_engine: Optional[InferenceEngine] = None
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan.

    On startup: build the swarm engine, run its blocking ``initialize`` on
    the default thread-pool executor (so the event loop stays responsive
    while the model loads), then wrap it in an inference engine.
    On shutdown: shut the swarm engine down.

    Raises:
        Exception: re-raised from initialization so the server fails fast
            instead of serving with a half-initialized swarm.
    """
    global swarm_engine, inference_engine

    logging.info("Initializing Mamba Swarm API Server...")
    try:
        swarm_engine = SwarmEngine()
        # get_running_loop() is the supported accessor inside a coroutine;
        # get_event_loop() is deprecated for this use since Python 3.10.
        await asyncio.get_running_loop().run_in_executor(None, swarm_engine.initialize)

        inference_engine = InferenceEngine(swarm_engine)
        logging.info("Mamba Swarm API Server initialized successfully")
    except Exception as e:
        logging.error(f"Failed to initialize swarm: {e}")
        raise

    yield

    logging.info("Shutting down Mamba Swarm API Server...")
    if swarm_engine:
        swarm_engine.shutdown()
|
|
|
|
|
# FastAPI application; model startup/shutdown is handled by the lifespan
# context manager defined above.
app = FastAPI(
    title="Mamba Swarm API",
    description="Distributed Mamba Language Model API with 100 encoder units",
    version="1.0.0",
    lifespan=lifespan
)

# Fully permissive CORS: any origin, method and header is accepted.
# NOTE(review): browsers reject allow_origins=["*"] combined with
# allow_credentials=True per the CORS spec — confirm whether credentialed
# cross-origin requests are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
async def get_swarm_engine() -> SwarmEngine:
    """FastAPI dependency: return the global swarm engine, or 503 if absent."""
    engine = swarm_engine
    if engine is None:
        raise HTTPException(status_code=503, detail="Swarm engine not initialized")
    return engine
|
|
|
async def get_inference_engine() -> InferenceEngine:
    """FastAPI dependency: return the global inference engine, or 503 if absent."""
    engine = inference_engine
    if engine is None:
        raise HTTPException(status_code=503, detail="Inference engine not initialized")
    return engine
|
|
|
@app.get("/health", response_model=HealthResponse)
async def health_check(swarm: SwarmEngine = Depends(get_swarm_engine)):
    """Health check endpoint.

    Reports the swarm's own status plus basic host information (CUDA
    availability/device count and the interpreter version).

    Raises:
        HTTPException: 500 if the swarm status query fails.
    """
    try:
        swarm_status = swarm.get_status()
        cuda_ok = torch.cuda.is_available()
        system_info = {
            "cuda_available": cuda_ok,
            "cuda_device_count": torch.cuda.device_count() if cuda_ok else 0,
            # Report the actual interpreter version instead of the previous
            # hard-coded "3.8+" placeholder.
            "python_version": platform.python_version(),
        }

        return HealthResponse(
            status="healthy",
            swarm_status=swarm_status,
            system_info=system_info,
            timestamp=time.time(),
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Health check failed: {str(e)}")
|
|
|
@app.get("/model/info", response_model=ModelInfo)
async def get_model_info(swarm: SwarmEngine = Depends(get_swarm_engine)):
    """Return parameter/encoder counts, memory usage and device list."""
    try:
        info = swarm.get_model_info()
        stats = swarm.memory_manager.get_memory_stats()

        default_device = "cuda:0" if torch.cuda.is_available() else "cpu"
        memory_usage = {
            "system_memory_gb": stats.used_memory,
            "gpu_memory_gb": stats.gpu_memory,
            "cache_size_gb": stats.cache_size,
        }

        return ModelInfo(
            total_parameters=info.get("total_parameters", 7000000000),
            active_encoders=info.get("active_encoders", 100),
            total_encoders=info.get("total_encoders", 100),
            memory_usage=memory_usage,
            device_info=info.get("devices", [default_device]),
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get model info: {str(e)}")
|
|
|
@app.post("/generate", response_model=GenerationResponse)
async def generate_text(
    request: GenerationRequest,
    inference: InferenceEngine = Depends(get_inference_engine)
):
    """Generate text from a prompt (non-streaming).

    The blocking ``inference.generate`` call is pushed onto the default
    thread-pool executor so the event loop keeps serving other requests.

    Raises:
        HTTPException: 500 wrapping any inference failure.
    """
    try:
        start_time = time.time()

        params = {
            "max_length": request.max_length,
            "temperature": request.temperature,
            "top_p": request.top_p,
            "top_k": request.top_k,
            "repetition_penalty": request.repetition_penalty,
            "domain": request.domain,
        }
        # get_running_loop() replaces the deprecated get_event_loop()
        # pattern for coroutines (deprecated since Python 3.10).
        result = await asyncio.get_running_loop().run_in_executor(
            None, inference.generate, request.prompt, params
        )

        generation_time = time.time() - start_time

        return GenerationResponse(
            generated_text=result["generated_text"],
            prompt=request.prompt,
            generation_time=generation_time,
            tokens_generated=result.get("tokens_generated", 0),
            model_info=result.get("model_info", {}),
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
|
|
|
@app.post("/generate/stream")
async def generate_text_stream(
    request: GenerationRequest,
    inference: InferenceEngine = Depends(get_inference_engine)
):
    """Generate text with a streaming (Server-Sent-Events style) response.

    Requires ``request.stream`` to be True; otherwise responds 400.
    """
    if not request.stream:
        raise HTTPException(status_code=400, detail="Streaming not requested")

    async def generate_stream() -> AsyncGenerator[str, None]:
        try:
            generator = inference.generate_stream(
                request.prompt,
                {
                    "max_length": request.max_length,
                    "temperature": request.temperature,
                    "top_p": request.top_p,
                    "top_k": request.top_k,
                    "repetition_penalty": request.repetition_penalty,
                    "domain": request.domain,
                },
            )

            # NOTE(review): this looks like a synchronous generator, so
            # iterating it here blocks the event loop between tokens —
            # confirm, and consider handing each next() to an executor.
            for token_data in generator:
                streaming_token = StreamingToken(
                    token=token_data.get("token", ""),
                    is_final=token_data.get("is_final", False),
                    metadata=token_data.get("metadata", {}),
                )
                yield f"data: {streaming_token.json()}\n\n"
                if streaming_token.is_final:
                    break

        except Exception as e:
            # Deliver the failure as a final token rather than cutting the
            # connection mid-stream.
            error_token = StreamingToken(
                token="",
                is_final=True,
                metadata={"error": str(e)},
            )
            yield f"data: {error_token.json()}\n\n"

    return StreamingResponse(
        generate_stream(),
        # The "data: ...\n\n" frames above are the SSE wire format, so
        # advertise the matching media type (was "text/plain", which breaks
        # EventSource clients).
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )
|
|
|
@app.post("/generate/batch")
async def generate_batch(
    requests: List[GenerationRequest],
    inference: InferenceEngine = Depends(get_inference_engine)
):
    """Generate text for up to 10 prompts concurrently.

    Each prompt is dispatched to the thread-pool executor and gathered with
    ``return_exceptions=True`` so one failing prompt yields a per-item error
    entry instead of failing the whole batch.

    Raises:
        HTTPException: 400 for batches larger than 10; 500 on unexpected
            failure outside individual generations.
    """
    if len(requests) > 10:
        raise HTTPException(status_code=400, detail="Batch size too large (max 10)")

    try:
        # get_running_loop() replaces the deprecated get_event_loop()
        # pattern for coroutines; hoisted out of the dispatch loop.
        loop = asyncio.get_running_loop()
        tasks = [
            loop.run_in_executor(
                None,
                inference.generate,
                req.prompt,
                {
                    "max_length": req.max_length,
                    "temperature": req.temperature,
                    "top_p": req.top_p,
                    "top_k": req.top_k,
                    "repetition_penalty": req.repetition_penalty,
                    "domain": req.domain,
                },
            )
            for req in requests
        ]

        results = await asyncio.gather(*tasks, return_exceptions=True)

        responses = []
        for i, (req, result) in enumerate(zip(requests, results)):
            if isinstance(result, Exception):
                # Per-item failure: report it in place, keep the batch going.
                responses.append({
                    "error": str(result),
                    "prompt": req.prompt,
                    "index": i,
                })
            else:
                responses.append(GenerationResponse(
                    generated_text=result["generated_text"],
                    prompt=req.prompt,
                    generation_time=result.get("generation_time", 0),
                    tokens_generated=result.get("tokens_generated", 0),
                    model_info=result.get("model_info", {}),
                ))

        return {"responses": responses}

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch generation failed: {str(e)}")
|
|
|
@app.get("/metrics")
async def get_metrics(swarm: SwarmEngine = Depends(get_swarm_engine)):
    """Return memory, swarm and (if available) inference metrics."""
    try:
        if hasattr(swarm, "get_inference_stats"):
            inference_stats = swarm.get_inference_stats()
        else:
            inference_stats = {}
        return {
            "memory_report": swarm.memory_manager.get_memory_report(),
            "swarm_metrics": swarm.get_metrics(),
            "inference_stats": inference_stats,
            "timestamp": time.time(),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get metrics: {str(e)}")
|
|
|
@app.post("/admin/reload")
async def reload_model(
    background_tasks: BackgroundTasks,
    swarm: SwarmEngine = Depends(get_swarm_engine)
):
    """Schedule a model reload to run after the response is sent (admin)."""
    try:
        background_tasks.add_task(swarm.reload_model)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to reload model: {str(e)}")
    return {"message": "Model reload initiated"}
|
|
|
@app.post("/admin/cleanup")
async def cleanup_memory(swarm: SwarmEngine = Depends(get_swarm_engine)):
    """Force an aggressive memory cleanup on the swarm (admin)."""
    try:
        swarm.memory_manager.cleanup_memory(aggressive=True)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to cleanup memory: {str(e)}")
    return {"message": "Memory cleanup completed"}
|
|
|
|
|
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Render HTTPExceptions as JSON with an added timestamp.

    FastAPI/Starlette exception handlers must return a Response object;
    the previous plain-dict return is not a valid handler result. The
    original status code is preserved on the response.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": exc.detail,
            "status_code": exc.status_code,
            "timestamp": time.time(),
        },
    )
|
|
|
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """Catch-all handler: log the error and return a generic JSON 500.

    Must return a Response object (the previous plain-dict return is not a
    valid handler result and would also have been sent with status 200).
    The exception detail is logged but not leaked to the client.
    """
    logging.error(f"Unhandled exception: {exc}")
    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal server error",
            "status_code": 500,
            "timestamp": time.time(),
        },
    )
|
|
|
def run_server(host: str = "0.0.0.0", port: int = 8000, workers: int = 1):
    """Run the API server with uvicorn.

    Args:
        host: Interface to bind (default all interfaces).
        port: TCP port to listen on.
        workers: Requested worker count. uvicorn ignores ``workers`` when
            given an application *object* (as here) instead of an import
            string, so values > 1 are logged as a warning rather than
            silently dropped.
    """
    setup_logging()

    if workers > 1:
        logging.warning(
            "uvicorn ignores workers=%d when passed an app object; "
            "pass an import string (e.g. 'module:app') to use multiple workers",
            workers,
        )

    config = uvicorn.Config(
        app=app,
        host=host,
        port=port,
        workers=workers,
        log_level="info",
        access_log=True,
        reload=False,
    )

    server = uvicorn.Server(config)
    server.run()
|
|
|
# Script entry point: start a single-worker server on 0.0.0.0:8000.
if __name__ == "__main__":
    run_server()