| """ |
| WebSocket Support Module |
| Provides real-time updates via WebSocket connections with connection management |
| """ |
|
|
| import asyncio |
| import json |
| from datetime import datetime |
| from typing import Set, Dict, Any, Optional, List |
| from fastapi import WebSocket, WebSocketDisconnect, APIRouter |
| from starlette.websockets import WebSocketState |
| from utils.logger import setup_logger |
| from database.db_manager import db_manager |
| from monitoring.rate_limiter import rate_limiter |
| from config import config |
|
|
| |
# Module-level logger used by every function and method in this module.
logger = setup_logger("websocket", level="INFO")


# Router that exposes the "/ws/live" WebSocket endpoint and the "/ws/stats" HTTP endpoint.
router = APIRouter()
|
|
|
|
class ConnectionManager:
    """
    Manages WebSocket connections and broadcasts messages to all connected clients.

    Each accepted socket is stored in ``active_connections``. A metadata dict
    with the keys ``client_id``, ``connected_at`` and ``last_ping`` (ISO-8601
    strings for the timestamps) is kept for it in ``connection_metadata``.

    Two background tasks can be started with ``start_background_tasks``:

    - a status broadcaster that runs every ``STATUS_INTERVAL_SECONDS``;
    - a heartbeat that sends a ``ping`` message every
      ``HEARTBEAT_INTERVAL_SECONDS``.
    """

    # Seconds between periodic status broadcasts.
    STATUS_INTERVAL_SECONDS = 10
    # Seconds between heartbeat pings.
    HEARTBEAT_INTERVAL_SECONDS = 30
    # Rate-limit usage percentage at or above which an alert is broadcast.
    RATE_LIMIT_ALERT_THRESHOLD = 80
    # Rate-limit usage percentage at or above which an alert is marked 'critical'.
    RATE_LIMIT_CRITICAL_THRESHOLD = 95

    def __init__(self):
        """Initialize an empty connection manager with no background tasks running."""
        self.active_connections: Set[WebSocket] = set()
        # Per-connection info: 'client_id', 'connected_at', 'last_ping'.
        self.connection_metadata: Dict[WebSocket, Dict[str, Any]] = {}
        self._broadcast_task: Optional[asyncio.Task] = None
        self._heartbeat_task: Optional[asyncio.Task] = None
        # Checked by both background loops on every iteration; cleared by stop_background_tasks().
        self._is_running = False

    async def connect(self, websocket: WebSocket, client_id: Optional[str] = None):
        """
        Accept and register a new WebSocket connection, then greet the client.

        Args:
            websocket: WebSocket connection to accept
            client_id: Optional client identifier; defaults to ``client_<id(websocket)>``
        """
        await websocket.accept()
        self.active_connections.add(websocket)

        now = datetime.utcnow().isoformat()
        self.connection_metadata[websocket] = {
            'client_id': client_id or f"client_{id(websocket)}",
            'connected_at': now,
            'last_ping': now
        }
        assigned_id = self.connection_metadata[websocket]['client_id']

        logger.info(
            f"WebSocket connected: {assigned_id} "
            f"(Total connections: {len(self.active_connections)})"
        )

        await self.send_personal_message(
            {
                'type': 'connection_established',
                'client_id': assigned_id,
                'timestamp': datetime.utcnow().isoformat(),
                'message': 'Connected to Crypto API Monitor WebSocket'
            },
            websocket
        )

    def disconnect(self, websocket: WebSocket):
        """
        Unregister a WebSocket connection and drop its metadata.

        Does not close the socket itself. Calling it for an unknown or
        already-removed connection is a no-op.

        Args:
            websocket: WebSocket connection to unregister
        """
        if websocket not in self.active_connections:
            return

        self.active_connections.remove(websocket)
        metadata = self.connection_metadata.pop(websocket, {})
        client_id = metadata.get('client_id', 'unknown')

        logger.info(
            f"WebSocket disconnected: {client_id} "
            f"(Remaining connections: {len(self.active_connections)})"
        )

    async def send_personal_message(self, message: Dict[str, Any], websocket: WebSocket):
        """
        Send a message to a specific WebSocket connection.

        The message is only sent if the socket is still CONNECTED. Any send
        error is logged and the connection is unregistered.

        Args:
            message: JSON-serializable message dictionary to send
            websocket: Target WebSocket connection
        """
        try:
            if websocket.client_state == WebSocketState.CONNECTED:
                await websocket.send_json(message)
        except Exception as e:
            logger.error(f"Error sending personal message: {e}")
            self.disconnect(websocket)

    async def _send_to(self, connection: WebSocket, message: Dict[str, Any]) -> bool:
        """Send ``message`` to one connection; return False if the connection should be dropped."""
        try:
            if connection.client_state != WebSocketState.CONNECTED:
                return False
            await connection.send_json(message)
            return True
        except Exception as e:
            logger.error(f"Error broadcasting to client: {e}")
            return False

    async def broadcast(self, message: Dict[str, Any]):
        """
        Broadcast a message to all connected clients.

        Sends run concurrently, so one slow or stalled client does not delay
        delivery to the others. Clients that are no longer CONNECTED, or whose
        send raised, are unregistered afterwards.

        Args:
            message: JSON-serializable message dictionary to broadcast
        """
        # Snapshot: the set may change while the sends are awaited.
        targets = list(self.active_connections)
        if not targets:
            return

        results = await asyncio.gather(*(self._send_to(conn, message) for conn in targets))

        for connection, delivered in zip(targets, results):
            if not delivered:
                self.disconnect(connection)

    async def broadcast_status_update(self):
        """
        Broadcast a system status update to all connected clients.

        The update uses the latest stored system metrics. If none exist, it
        falls back to the configured provider count and zero/'unknown' values.
        It also includes the number of unacknowledged alerts from the last hour.
        Errors are logged, not raised.
        """
        try:
            # NOTE(review): these calls are not awaited, so they run on the event
            # loop thread; if they perform blocking I/O consider offloading them.
            latest_metrics = db_manager.get_latest_system_metrics()
            alerts = db_manager.get_alerts(acknowledged=False, hours=1)

            if latest_metrics:
                system_metrics = {
                    'total_providers': latest_metrics.total_providers,
                    'online_count': latest_metrics.online_count,
                    'degraded_count': latest_metrics.degraded_count,
                    'offline_count': latest_metrics.offline_count,
                    'avg_response_time_ms': latest_metrics.avg_response_time_ms,
                    'total_requests_hour': latest_metrics.total_requests_hour,
                    'total_failures_hour': latest_metrics.total_failures_hour,
                    'system_health': latest_metrics.system_health
                }
            else:
                system_metrics = {
                    'total_providers': len(config.get_all_providers()),
                    'online_count': 0,
                    'degraded_count': 0,
                    'offline_count': 0,
                    'avg_response_time_ms': 0,
                    'total_requests_hour': 0,
                    'total_failures_hour': 0,
                    'system_health': 'unknown'
                }

            message = {
                'type': 'status_update',
                'timestamp': datetime.utcnow().isoformat(),
                'system_metrics': system_metrics,
                'alert_count': len(alerts),
                'active_websocket_clients': len(self.active_connections)
            }

            await self.broadcast(message)
            logger.debug(f"Broadcasted status update to {len(self.active_connections)} clients")

        except Exception as e:
            logger.error(f"Error broadcasting status update: {e}", exc_info=True)

    async def broadcast_new_log_entry(self, log_type: str, log_data: Dict[str, Any]):
        """
        Broadcast a new log entry.

        Args:
            log_type: Type of log (connection, failure, collection, rate_limit)
            log_data: JSON-serializable log data dictionary
        """
        try:
            message = {
                'type': 'new_log_entry',
                'timestamp': datetime.utcnow().isoformat(),
                'log_type': log_type,
                'data': log_data
            }

            await self.broadcast(message)
            logger.debug(f"Broadcasted new {log_type} log entry")

        except Exception as e:
            logger.error(f"Error broadcasting log entry: {e}", exc_info=True)

    async def broadcast_rate_limit_alert(self, provider_name: str, percentage: float):
        """
        Broadcast a rate limit alert.

        Severity is 'critical' at or above ``RATE_LIMIT_CRITICAL_THRESHOLD``
        percent, and 'warning' otherwise.

        Args:
            provider_name: Provider name
            percentage: Current usage percentage
        """
        try:
            message = {
                'type': 'rate_limit_alert',
                'timestamp': datetime.utcnow().isoformat(),
                'provider': provider_name,
                'percentage': percentage,
                'severity': 'critical' if percentage >= self.RATE_LIMIT_CRITICAL_THRESHOLD else 'warning'
            }

            await self.broadcast(message)
            logger.info(f"Broadcasted rate limit alert for {provider_name} ({percentage}%)")

        except Exception as e:
            logger.error(f"Error broadcasting rate limit alert: {e}", exc_info=True)

    async def broadcast_provider_status_change(
        self,
        provider_name: str,
        old_status: str,
        new_status: str,
        details: Optional[Dict] = None
    ):
        """
        Broadcast a provider status change.

        Args:
            provider_name: Provider name
            old_status: Previous status
            new_status: New status
            details: Optional details about the change (sent as ``{}`` when omitted)
        """
        try:
            message = {
                'type': 'provider_status_change',
                'timestamp': datetime.utcnow().isoformat(),
                'provider': provider_name,
                'old_status': old_status,
                'new_status': new_status,
                'details': details or {}
            }

            await self.broadcast(message)
            logger.info(
                f"Broadcasted provider status change: {provider_name} "
                f"{old_status} -> {new_status}"
            )

        except Exception as e:
            logger.error(f"Error broadcasting provider status change: {e}", exc_info=True)

    async def _broadcast_rate_limit_alerts(self):
        """Broadcast an alert for every provider whose usage is at or above the alert threshold."""
        rate_limit_statuses = rate_limiter.get_all_statuses()
        for provider, status_data in rate_limit_statuses.items():
            if status_data and status_data.get('percentage', 0) >= self.RATE_LIMIT_ALERT_THRESHOLD:
                await self.broadcast_rate_limit_alert(provider, status_data['percentage'])

    async def _periodic_broadcast_loop(self):
        """
        Background task that broadcasts updates every ``STATUS_INTERVAL_SECONDS``.

        Each cycle sends a status update and any rate-limit alerts. Errors are
        logged and the loop keeps running until ``_is_running`` is cleared or
        the task is cancelled.
        """
        logger.info("Starting periodic broadcast loop")

        while self._is_running:
            try:
                # With no clients there is nobody to deliver to, so skip the
                # DB / rate-limiter queries entirely.
                if self.active_connections:
                    await self.broadcast_status_update()
                    await self._broadcast_rate_limit_alerts()

                await asyncio.sleep(self.STATUS_INTERVAL_SECONDS)

            except Exception as e:
                logger.error(f"Error in periodic broadcast loop: {e}", exc_info=True)
                await asyncio.sleep(self.STATUS_INTERVAL_SECONDS)

        logger.info("Periodic broadcast loop stopped")

    async def _heartbeat_loop(self):
        """
        Background task that sends a 'ping' message to all clients every
        ``HEARTBEAT_INTERVAL_SECONDS``.
        """
        logger.info("Starting heartbeat loop")

        while self._is_running:
            try:
                ping_message = {
                    'type': 'ping',
                    'timestamp': datetime.utcnow().isoformat()
                }

                await self.broadcast(ping_message)
                await asyncio.sleep(self.HEARTBEAT_INTERVAL_SECONDS)

            except Exception as e:
                logger.error(f"Error in heartbeat loop: {e}", exc_info=True)
                await asyncio.sleep(self.HEARTBEAT_INTERVAL_SECONDS)

        logger.info("Heartbeat loop stopped")

    async def start_background_tasks(self):
        """
        Start the background broadcast and heartbeat tasks.

        Does nothing except log a warning if they are already running.
        """
        if self._is_running:
            logger.warning("Background tasks already running")
            return

        # Set before creating the tasks so their while-loops see True on entry.
        self._is_running = True

        self._broadcast_task = asyncio.create_task(self._periodic_broadcast_loop())
        logger.info("Started periodic broadcast task")

        self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
        logger.info("Started heartbeat task")

    async def stop_background_tasks(self):
        """
        Stop the background broadcast and heartbeat tasks.

        Both tasks are cancelled first and then awaited, so shutdown is not
        serialized behind the first task. Task references are cleared afterwards.
        """
        if not self._is_running:
            logger.warning("Background tasks not running")
            return

        self._is_running = False

        tasks = [
            (self._broadcast_task, "Stopped periodic broadcast task"),
            (self._heartbeat_task, "Stopped heartbeat task"),
        ]

        for task, _ in tasks:
            if task:
                task.cancel()

        for task, stopped_msg in tasks:
            if task:
                try:
                    await task
                except asyncio.CancelledError:
                    pass
                logger.info(stopped_msg)

        self._broadcast_task = None
        self._heartbeat_task = None

    async def close_all_connections(self):
        """
        Close all active WebSocket connections with code 1000 and clear all tracking state.

        Errors while closing an individual socket are logged and do not stop the rest.
        """
        logger.info(f"Closing {len(self.active_connections)} active connections")

        for connection in self.active_connections.copy():
            try:
                if connection.client_state == WebSocketState.CONNECTED:
                    await connection.close(code=1000, reason="Server shutdown")
            except Exception as e:
                logger.error(f"Error closing connection: {e}")

        self.active_connections.clear()
        self.connection_metadata.clear()
        logger.info("All WebSocket connections closed")

    def get_connection_count(self) -> int:
        """
        Get the number of active connections.

        Returns:
            Number of active connections
        """
        return len(self.active_connections)

    def get_connection_info(self) -> List[Dict[str, Any]]:
        """
        Get information about all active connections.

        Returns:
            One dict per connection with 'client_id', 'connected_at' and 'last_ping'
        """
        return [
            {
                'client_id': metadata['client_id'],
                'connected_at': metadata['connected_at'],
                'last_ping': metadata['last_ping']
            }
            for metadata in self.connection_metadata.values()
        ]
|
|
|
|
| |
# Shared ConnectionManager instance; the endpoints in this module use it and it is listed in __all__.
manager = ConnectionManager()
|
|
|
|
@router.websocket("/ws/live")
async def websocket_live_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint for real-time updates

    Provides:
    - System status updates every 10 seconds
    - Real-time log entries
    - Rate limit alerts
    - Provider status changes
    - Heartbeat pings every 30 seconds

    Message Types:
    - connection_established: Sent when client connects
    - status_update: Periodic system status (every 10s)
    - new_log_entry: New log entry notification
    - rate_limit_alert: Rate limit warning
    - provider_status_change: Provider status change
    - ping: Heartbeat ping (every 30s)

    Client messages are JSON objects with a 'type' field:
    - 'pong' refreshes the connection's 'last_ping' timestamp.
    - 'subscribe' and 'unsubscribe' are only logged.
    - Invalid JSON and non-object JSON are logged and ignored.
    """
    # Upper bound on how much of an untrusted client payload is written to the log.
    max_logged_chars = 200
    client_id = None

    try:
        await manager.connect(websocket)
        client_id = manager.connection_metadata.get(websocket, {}).get('client_id', 'unknown')

        # Background loops are started lazily by the first connecting client.
        if not manager._is_running:
            await manager.start_background_tasks()

        while True:
            try:
                data = await websocket.receive_text()

                try:
                    message = json.loads(data)
                except json.JSONDecodeError:
                    logger.warning(f"Received invalid JSON from {client_id}: {data[:max_logged_chars]}")
                    continue

                # Valid JSON that is not an object (list, number, string, null)
                # has no 'type'. Ignore it instead of letting `.get` raise and
                # drop the connection.
                if not isinstance(message, dict):
                    logger.warning(f"Received non-object JSON from {client_id}: {data[:max_logged_chars]}")
                    continue

                message_type = message.get('type')

                if message_type == 'pong':
                    if websocket in manager.connection_metadata:
                        manager.connection_metadata[websocket]['last_ping'] = datetime.utcnow().isoformat()
                    logger.debug(f"Received pong from {client_id}")

                elif message_type == 'subscribe':
                    logger.debug(f"Client {client_id} subscription request: {message}")

                elif message_type == 'unsubscribe':
                    logger.debug(f"Client {client_id} unsubscribe request: {message}")

            except WebSocketDisconnect:
                logger.info(f"Client {client_id} disconnected")
                break

            except Exception as e:
                logger.error(f"Error handling message from {client_id}: {e}", exc_info=True)
                break

    except Exception as e:
        logger.error(f"WebSocket error for {client_id}: {e}", exc_info=True)

    finally:
        # Always unregister; disconnect() is a no-op if the socket was already removed.
        manager.disconnect(websocket)
|
|
|
|
@router.get("/ws/stats")
async def websocket_stats():
    """
    Report a snapshot of the WebSocket manager's state.

    Returns:
        A dict with the number of active connections, per-connection metadata,
        whether the background tasks are running, and a UTC ISO timestamp.
    """
    snapshot = dict(
        active_connections=manager.get_connection_count(),
        connections=manager.get_connection_info(),
        background_tasks_running=manager._is_running,
        timestamp=datetime.utcnow().isoformat(),
    )
    return snapshot
|
|
|
|
| |
# Public names exported by `from <module> import *`.
__all__ = ['router', 'manager', 'ConnectionManager']
|
|