"""FastAPI service exposing a bank-fraud-detection model over HTTP.

Routes:
    GET  /health         — liveness/readiness probe reporting model status.
    POST /predict        — score a single transaction description.
    POST /predict/batch  — score a list of transaction descriptions.
    POST /analyze        — richer single-text analysis (score + verdict).

The model is loaded once at startup via the lifespan context manager and
torn down on shutdown.
"""

from contextlib import asynccontextmanager
from typing import List, Optional
import logging

import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, field_validator

from fraud_model import FraudDetector

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global detector instance; populated by the lifespan handler.
# Annotated Optional because it is None before startup and after shutdown.
detector: Optional[FraudDetector] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the FraudDetector at startup and release it at shutdown.

    Raises:
        RuntimeError: if the model fails to initialize, so the server
            aborts startup instead of serving 503s indefinitely.
    """
    global detector
    try:
        logger.info("Loading FraudDetector model...")
        detector = FraudDetector()
        logger.info("FraudDetector loaded successfully.")
    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception("FATAL: Failed to initialize FraudDetector: %s", e)
        raise RuntimeError(f"Model failed to load: {e}")
    yield
    detector = None
    logger.info("FraudDetector shut down.")


app = FastAPI(
    title="Bank Fraud Detection API",
    description="API for detecting fraudulent bank transactions using AI.",
    version="1.0.0",
    lifespan=lifespan,
)


# --- Request / Response Models ---

class PredictionRequest(BaseModel):
    """Single-text scoring request. `text` is stripped and must be non-empty."""

    text: str

    @field_validator("text")
    @classmethod
    def text_must_not_be_empty(cls, v: str) -> str:
        if not v or not v.strip():
            raise ValueError("text must not be empty")
        return v.strip()


class BatchPredictionRequest(BaseModel):
    """Batch scoring request. Blank entries are dropped; the cleaned list
    must not be empty."""

    texts: List[str]

    @field_validator("texts")
    @classmethod
    def texts_must_not_be_empty(cls, v: List[str]) -> List[str]:
        if not v:
            raise ValueError("texts list must not be empty")
        # Strip whitespace and discard entries that are empty or blank-only.
        cleaned = [t.strip() for t in v if t and t.strip()]
        if not cleaned:
            raise ValueError("texts list contains only empty strings")
        return cleaned


class PredictionResponse(BaseModel):
    """Fraud score for a single text plus a categorical risk level."""

    text: str
    fraud_score: float
    risk_level: str


class AnalyzeResponse(BaseModel):
    """Extended analysis: score, risk level, binary verdict, and a
    human-readable detection summary."""

    text: str
    fraud_score: float
    risk_level: str
    is_fraud: bool
    detection: str


# --- Routes ---

@app.get("/health")
def health_check():
    """Report service health and the loaded model's name.

    Returns an "unhealthy" payload (HTTP 200) when the model is absent so
    orchestrators can distinguish degraded from down.
    """
    # Identity check: a loaded-but-falsy detector must not read as missing.
    if detector is not None:
        return {"status": "healthy", "model": detector.model_name}
    return {"status": "unhealthy", "error": "Model not loaded"}


@app.post("/predict", response_model=PredictionResponse)
def predict_single(request: PredictionRequest):
    """Score one transaction text.

    Raises:
        HTTPException: 503 if the model is not loaded, 500 on inference error.
    """
    if detector is None:
        raise HTTPException(status_code=503, detail="Model service unavailable")
    try:
        result = detector.predict(request.text)
        return result
    except Exception as e:
        logger.exception("Prediction error: %s", e)
        # NOTE(review): str(e) in the response may leak internals — consider
        # a generic message if this API is externally exposed.
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/predict/batch", response_model=List[PredictionResponse])
def predict_batch(request: BatchPredictionRequest):
    """Score a batch of transaction texts in one call.

    Raises:
        HTTPException: 503 if the model is not loaded, 500 on inference error.
    """
    if detector is None:
        raise HTTPException(status_code=503, detail="Model service unavailable")
    try:
        results = detector.predict_batch(request.texts)
        return results
    except Exception as e:
        logger.exception("Batch prediction error: %s", e)
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/analyze", response_model=AnalyzeResponse)
def analyze(request: PredictionRequest):
    """Run the richer analysis (score + binary verdict + explanation).

    Raises:
        HTTPException: 503 if the model is not loaded, 500 on inference error.
    """
    if detector is None:
        raise HTTPException(status_code=503, detail="Model service unavailable")
    try:
        result = detector.analyze(request.text)
        return result
    except Exception as e:
        logger.exception("Analyze error: %s", e)
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)