Spaces:
Runtime error
Runtime error
| """ | |
| Session Manager - Multi-session rotation and load balancing | |
| Handles multiple Pyrogram sessions to maximize bandwidth and avoid flood limits | |
| """ | |
| import asyncio | |
| import logging | |
| import os | |
| from typing import List, Optional, AsyncGenerator | |
| from io import BytesIO | |
| from pyrogram import Client | |
| from pyrogram.errors import FloodWait, BadRequest | |
| from pyrogram.types import Message | |
| from utils import CHUNK_SIZE | |
| logger = logging.getLogger(__name__) | |
class TelegramSession:
    """Wrapper for a single Pyrogram session and its usage counters."""

    def __init__(
        self,
        session_name: str,
        api_id: int,
        api_hash: str,
        session_string: Optional[str] = None
    ):
        """Store the connection settings; no client is created until initialize().

        Args:
            session_name: Name passed to the Pyrogram client and used in logs.
            api_id: Telegram API id passed to the client.
            api_hash: Telegram API hash passed to the client.
            session_string: Optional session string; when set it is passed to
                the client as ``session_string``.
        """
        self.session_name = session_name
        self.api_id = api_id
        self.api_hash = api_hash
        self.session_string = session_string
        # Created by initialize(); None until then.
        self.client: Optional[Client] = None
        # True only between a successful start and a stop attempt.
        self.is_active = False
        # Counters exposed for statistics; this class never increments them.
        self.upload_count = 0
        self.download_count = 0

    async def initialize(self):
        """Create and start the Pyrogram client, then log the account it uses.

        Raises:
            Exception: Whatever client construction, ``start()`` or
                ``get_me()`` raised. The error is logged, ``is_active`` is
                reset to False and the exception is re-raised.
        """
        client_kwargs = {
            "name": self.session_name,
            "api_id": self.api_id,
            "api_hash": self.api_hash,
            "in_memory": True,
            "no_updates": True,
        }
        if self.session_string:
            client_kwargs["session_string"] = self.session_string

        started = False
        try:
            self.client = Client(**client_kwargs)
            await self.client.start()
            started = True
            self.is_active = True
            me = await self.client.get_me()
            logger.info(
                f"Session {self.session_name} initialized: "
                f"@{me.username or me.first_name}"
            )
        except Exception as e:
            logger.error(f"Failed to initialize session {self.session_name}: {e}")
            self.is_active = False
            if started:
                # The client is already running; cleanup() skips inactive
                # sessions, so it has to be stopped here or it would leak.
                try:
                    await self.client.stop()
                except Exception as stop_error:
                    logger.error(
                        f"Error stopping session {self.session_name}: {stop_error}"
                    )
            raise

    async def cleanup(self):
        """Stop the client if it is running; errors are logged, not raised."""
        if not (self.client and self.is_active):
            return
        try:
            await self.client.stop()
            logger.info(f"Session {self.session_name} stopped")
        except Exception as e:
            logger.error(f"Error stopping session {self.session_name}: {e}")
        finally:
            # After a stop attempt the session must not be reported or picked
            # as active, even if stop() itself failed.
            self.is_active = False
class SessionManager:
    """Manages multiple Telegram sessions for load balancing.

    Uploads and downloads each have their own round-robin position over
    ``self.sessions``. The bot session is always started; it joins the pool
    only when no user session could be initialized.
    """

    def __init__(self):
        """Create an empty manager; call initialize() before using it."""
        self.sessions: List["TelegramSession"] = []
        self.bot_token: Optional[str] = None
        self.bot_session: Optional["TelegramSession"] = None
        # Independent rotation positions for uploads and downloads.
        self.current_upload_index = 0
        self.current_download_index = 0
        # Not acquired by any method of this class.
        self.lock = asyncio.Lock()

    async def initialize(self):
        """Initialize all sessions from environment variables.

        Reads BOT_TOKEN, API_ID, API_HASH and the comma-separated
        SESSION_STRINGS. User sessions that fail to start are logged and
        skipped; if none remain, the bot session becomes the only pool member.

        Raises:
            ValueError: If BOT_TOKEN, API_ID or API_HASH is missing, or API_ID
                is not an integer.
            Exception: Whatever starting the bot client raises.
        """
        logger.info("Initializing Session Manager...")

        self.bot_token = os.getenv("BOT_TOKEN")
        if not self.bot_token:
            raise ValueError("BOT_TOKEN environment variable is required")

        try:
            api_id = int(os.getenv("API_ID", "0"))
        except ValueError:
            raise ValueError(
                "API_ID environment variable must be an integer"
            ) from None
        api_hash = os.getenv("API_HASH", "")
        if not api_id or not api_hash:
            raise ValueError("API_ID and API_HASH environment variables are required")

        # The bot session needs a bot_token, so its client is built here
        # instead of through TelegramSession.initialize().
        self.bot_session = TelegramSession(
            session_name="bot_session",
            api_id=api_id,
            api_hash=api_hash
        )
        self.bot_session.client = Client(
            name="bot_session",
            api_id=api_id,
            api_hash=api_hash,
            bot_token=self.bot_token,
            in_memory=True,
            no_updates=True
        )
        await self.bot_session.client.start()
        self.bot_session.is_active = True
        logger.info("Bot session initialized")

        session_strings = [
            s.strip()
            for s in os.getenv("SESSION_STRINGS", "").split(",")
            if s.strip()
        ]
        if not session_strings:
            logger.warning("No SESSION_STRINGS found, using bot session only")
            self.sessions = [self.bot_session]
            return

        # Build the pool locally so a repeated initialize() call replaces the
        # previous pool instead of appending duplicates to it.
        user_sessions = []
        for i, session_string in enumerate(session_strings):
            session = TelegramSession(
                session_name=f"user_session_{i}",
                api_id=api_id,
                api_hash=api_hash,
                session_string=session_string
            )
            try:
                await session.initialize()
                user_sessions.append(session)
            except Exception as e:
                logger.error(f"Failed to initialize user session {i}: {e}")

        # Fall back to the bot session when every user session failed.
        self.sessions = user_sessions or [self.bot_session]
        logger.info(f"Session Manager initialized with {len(self.sessions)} session(s)")

    def _next_session(self, index_attr: str) -> "TelegramSession":
        """Advance the rotation stored in *index_attr* and return a session.

        Sessions that are inactive or have no client are skipped. If none is
        usable, the session at the current position is returned anyway, so
        callers keep their own ``is_active`` checks.
        """
        if not self.sessions:
            raise RuntimeError("No active sessions available")
        count = len(self.sessions)
        # Modulo guards against a position left over from a larger pool.
        start = getattr(self, index_attr) % count
        chosen = start
        for step in range(count):
            candidate = (start + step) % count
            session = self.sessions[candidate]
            if session.is_active and session.client:
                chosen = candidate
                break
        setattr(self, index_attr, (chosen + 1) % count)
        return self.sessions[chosen]

    def get_next_upload_session(self) -> "TelegramSession":
        """Get next session for upload (round-robin, preferring active ones).

        Raises:
            RuntimeError: If the pool is empty.
        """
        return self._next_session("current_upload_index")

    def get_next_download_session(self) -> "TelegramSession":
        """Get next session for download (round-robin, preferring active ones).

        Raises:
            RuntimeError: If the pool is empty.
        """
        return self._next_session("current_download_index")

    async def upload_part(
        self,
        data: bytes,
        filename: str,
        max_retries: int = 3
    ) -> str:
        """Upload a file part to Telegram.

        Each attempt uses the next upload session. A FloodWait sleeps for the
        time reported by the error; any other failure backs off exponentially
        before the next attempt.

        Args:
            data: Raw bytes of the part.
            filename: File name given to the uploaded document.
            max_retries: Maximum number of attempts.

        Returns:
            The ``file_id`` of the uploaded document, for later retrieval.

        Raises:
            RuntimeError: If the pool is empty or every attempt failed.
        """
        retry_count = 0
        while retry_count < max_retries:
            session = self.get_next_upload_session()
            if not session.is_active or not session.client:
                retry_count += 1
                continue
            try:
                # Upload to "Saved Messages" (self chat).
                # NOTE(review): presumably only valid for user sessions, not
                # the bot fallback -- confirm against the Telegram Bot API.
                message: Message = await session.client.send_document(
                    chat_id="me",
                    document=BytesIO(data),
                    file_name=filename,
                    force_document=True
                )
                session.upload_count += 1
                file_id = message.document.file_id
                logger.debug(
                    f"Part uploaded via {session.session_name}: "
                    f"file_id={file_id}, size={len(data)}"
                )
                return file_id
            except FloodWait as e:
                logger.warning(
                    f"FloodWait on {session.session_name}: waiting {e.value}s"
                )
                await asyncio.sleep(e.value)
                retry_count += 1
            except Exception as e:
                logger.error(
                    f"Upload failed on {session.session_name}: {e}"
                )
                retry_count += 1
                if retry_count < max_retries:
                    await asyncio.sleep(2 ** retry_count)
        raise RuntimeError(f"Failed to upload part after {max_retries} retries")

    async def stream_part(
        self,
        file_id: str,
        offset: int = 0,
        limit: Optional[int] = None
    ) -> AsyncGenerator[bytes, None]:
        """Stream a file part from Telegram, yielding chunks of data.

        Args:
            file_id: Identifier returned by upload_part().
            offset: Passed to ``stream_media`` as ``offset``.
                NOTE(review): Pyrogram documents this as a number of chunks
                to skip, not bytes -- confirm callers agree.
            limit: Maximum number of bytes to yield; None or 0 means no limit.
                It is also passed to ``stream_media`` as ``limit``; the byte
                trimming below enforces the actual bound.

        Raises:
            RuntimeError: If the pool is empty or the selected session is not
                usable.
            Exception: Any non-FloodWait streaming error, after logging.
        """
        session = self.get_next_download_session()
        if not session.is_active or not session.client:
            raise RuntimeError("No active session available for streaming")
        bytes_read = 0
        chunks_done = 0
        try:
            async for chunk in session.client.stream_media(
                file_id,
                offset=offset,
                limit=limit or 0
            ):
                if limit and bytes_read + len(chunk) > limit:
                    # Trim the final chunk to the byte limit.
                    tail = chunk[:limit - bytes_read]
                    yield tail
                    bytes_read += len(tail)
                    break
                yield chunk
                bytes_read += len(chunk)
                chunks_done += 1
                if limit and bytes_read >= limit:
                    break
            session.download_count += 1
            logger.debug(
                f"Part streamed via {session.session_name}: "
                f"file_id={file_id}, bytes={bytes_read}"
            )
        except FloodWait as e:
            logger.warning(f"FloodWait on download: waiting {e.value}s")
            await asyncio.sleep(e.value)
            # Retry with the next session, resuming after the chunks already
            # delivered so the consumer does not receive them twice.
            remaining = limit - bytes_read if limit else None
            async for chunk in self.stream_part(
                file_id, offset + chunks_done, remaining
            ):
                yield chunk
        except Exception as e:
            logger.error(f"Stream failed: {e}")
            raise

    async def delete_part(self, file_id: str) -> bool:
        """Delete a file part from Telegram (not implemented).

        Deleting requires the message_id, which is not stored; for production,
        consider storing message_ids in metadata.

        Returns:
            Always False.

        Raises:
            RuntimeError: If the pool is empty.
        """
        session = self.get_next_upload_session()
        if not session.is_active or not session.client:
            return False
        logger.warning("Delete operation requires message_id, not implemented")
        return False

    async def cleanup(self):
        """Stop every pooled session, and the bot session if it is not pooled."""
        logger.info("Cleaning up Session Manager...")
        for session in self.sessions:
            await session.cleanup()
        # Membership test instead of indexing sessions[0]: the pool can be
        # empty here, e.g. when initialize() failed while starting the bot.
        bot = self.bot_session
        if bot is not None and all(s is not bot for s in self.sessions):
            await bot.cleanup()
        logger.info("Session Manager cleanup complete")

    def get_stats(self) -> dict:
        """Get session statistics.

        Returns:
            A dict with ``total_sessions``, ``active_sessions`` and a
            ``sessions`` list holding each session's name, active flag and
            upload/download counters.
        """
        return {
            "total_sessions": len(self.sessions),
            "active_sessions": sum(1 for s in self.sessions if s.is_active),
            "sessions": [
                {
                    "name": s.session_name,
                    "active": s.is_active,
                    "uploads": s.upload_count,
                    "downloads": s.download_count
                }
                for s in self.sessions
            ]
        }