| import time
|
| from collections import defaultdict
|
| from fastapi import Request
|
| from fastapi.responses import JSONResponse
|
| from starlette.middleware.base import BaseHTTPMiddleware
|
| import logging
|
| import asyncio
|
|
|
| logger = logging.getLogger(__name__)
|
|
|
| class RateLimitMiddleware(BaseHTTPMiddleware):
|
| def __init__(
|
| self,
|
| app,
|
| rate_limit_per_minute=10,
|
| rate_limit_window=60,
|
| protected_routes=["/generate", "/api/generate", "/api/generate-with-report"]
|
| ):
|
| super().__init__(app)
|
| self.rate_limit_per_minute = rate_limit_per_minute
|
| self.rate_limit_window = rate_limit_window
|
| self.protected_routes = protected_routes
|
| self.ip_requests = defaultdict(list)
|
| logger.info(f"Rate limit middleware initialized: {rate_limit_per_minute} requests per {rate_limit_window}s")
|
|
|
| async def dispatch(self, request: Request, call_next):
|
| client_ip = request.client.host
|
| current_time = time.time()
|
|
|
|
|
| if any(request.url.path.startswith(route) for route in self.protected_routes):
|
|
|
| self.ip_requests[client_ip] = [t for t in self.ip_requests[client_ip]
|
| if current_time - t < self.rate_limit_window]
|
|
|
|
|
|
|
| if len(self.ip_requests) > 1000 and hash(client_ip) % 100 == 0:
|
| self._cleanup_all_ips(current_time)
|
|
|
|
|
| if len(self.ip_requests[client_ip]) >= self.rate_limit_per_minute:
|
| logger.warning(f"Rate limit exceeded for IP {client_ip} on {request.url.path}")
|
| return JSONResponse(
|
| status_code=429,
|
| content={"detail": "Rate limit exceeded. Please try again later."}
|
| )
|
|
|
|
|
| self.ip_requests[client_ip].append(current_time)
|
|
|
|
|
| response = await call_next(request)
|
| return response
|
|
|
| def _cleanup_all_ips(self, current_time):
|
| """Remove IPs that haven't made requests in the window"""
|
| to_remove = []
|
| for ip, timestamps in self.ip_requests.items():
|
|
|
| if not timestamps or (current_time - timestamps[-1] > self.rate_limit_window):
|
| to_remove.append(ip)
|
| for ip in to_remove:
|
| del self.ip_requests[ip]
|
|
|
| class ConcurrencyLimitMiddleware(BaseHTTPMiddleware):
|
| def __init__(
|
| self,
|
| app,
|
| max_concurrent_requests=5,
|
| timeout=5.0,
|
| protected_routes=None
|
| ):
|
| super().__init__(app)
|
| self.semaphore = asyncio.Semaphore(max_concurrent_requests)
|
| self.timeout = timeout
|
| self.protected_routes = protected_routes or ["/generate", "/api/generate", "/api/generate-with-report"]
|
| logger.info(f"Concurrency limit middleware initialized: {max_concurrent_requests} concurrent requests")
|
|
|
| async def dispatch(self, request, call_next):
|
| try:
|
|
|
| if any(request.url.path.startswith(route) for route in self.protected_routes):
|
| try:
|
|
|
| acquired = False
|
| try:
|
|
|
| await asyncio.wait_for(self.semaphore.acquire(), timeout=self.timeout)
|
| acquired = True
|
| return await call_next(request)
|
| finally:
|
| if acquired:
|
| self.semaphore.release()
|
| except asyncio.TimeoutError:
|
|
|
| logger.warning(f"Concurrency limit reached for {request.url.path}")
|
| return JSONResponse(
|
| status_code=503,
|
| content={"detail": "Server is at capacity. Please try again later."}
|
| )
|
| else:
|
|
|
| return await call_next(request)
|
| except Exception as e:
|
| logger.error(f"Error in ConcurrencyLimitMiddleware: {str(e)}")
|
| return JSONResponse(
|
| status_code=500,
|
| content={"detail": f"Internal server error in middleware: {str(e)}"}
|
| )
|
|
|
|
|
|
|
| class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
|
| def __init__(self, app, max_content_length=1024*1024):
|
| super().__init__(app)
|
| self.max_content_length = max_content_length
|
| logger.info(f"Request size limit middleware initialized: {max_content_length} bytes")
|
|
|
| async def dispatch(self, request: Request, call_next):
|
| content_length = request.headers.get('content-length')
|
| if content_length:
|
| if int(content_length) > self.max_content_length:
|
| logger.warning(f"Request too large: {content_length} bytes")
|
| return JSONResponse(
|
| status_code=413,
|
| content={"detail": "Request too large"}
|
| )
|
| return await call_next(request)
|
|
|