Spaces:
Sleeping
Sleeping
| """Token manager for Flow2API with AT auto-refresh""" | |
| import asyncio | |
| from datetime import datetime, timedelta, timezone | |
| from typing import Optional, List | |
| from ..core.database import Database | |
| from ..core.models import Token, Project | |
| from ..core.logger import debug_logger | |
| from .flow_client import FlowClient | |
| from .proxy_manager import ProxyManager | |
| class TokenManager: | |
| """Token lifecycle manager with AT auto-refresh""" | |
| def __init__(self, db: Database, flow_client: FlowClient): | |
| self.db = db | |
| self.flow_client = flow_client | |
| self._lock = asyncio.Lock() | |
| # ========== Token CRUD ========== | |
| async def get_all_tokens(self) -> List[Token]: | |
| """Get all tokens""" | |
| return await self.db.get_all_tokens() | |
| async def get_active_tokens(self) -> List[Token]: | |
| """Get all active tokens""" | |
| return await self.db.get_active_tokens() | |
| async def get_token(self, token_id: int) -> Optional[Token]: | |
| """Get token by ID""" | |
| return await self.db.get_token(token_id) | |
| async def delete_token(self, token_id: int): | |
| """Delete token""" | |
| await self.db.delete_token(token_id) | |
| async def enable_token(self, token_id: int): | |
| """Enable a token and reset error count""" | |
| # Enable the token | |
| await self.db.update_token(token_id, is_active=True) | |
| # Reset error count when enabling (only reset total error_count, keep today_error_count) | |
| await self.db.reset_error_count(token_id) | |
| async def disable_token(self, token_id: int): | |
| """Disable a token""" | |
| await self.db.update_token(token_id, is_active=False) | |
| # ========== Token添加 (支持Project创建) ========== | |
| async def add_token( | |
| self, | |
| st: str, | |
| project_id: Optional[str] = None, | |
| project_name: Optional[str] = None, | |
| remark: Optional[str] = None, | |
| image_enabled: bool = True, | |
| video_enabled: bool = True, | |
| image_concurrency: int = -1, | |
| video_concurrency: int = -1 | |
| ) -> Token: | |
| """Add a new token | |
| Args: | |
| st: Session Token (必需) | |
| project_id: 项目ID (可选,如果提供则直接使用,不创建新项目) | |
| project_name: 项目名称 (可选,如果不提供则自动生成) | |
| remark: 备注 | |
| image_enabled: 是否启用图片生成 | |
| video_enabled: 是否启用视频生成 | |
| image_concurrency: 图片并发限制 | |
| video_concurrency: 视频并发限制 | |
| Returns: | |
| Token object | |
| """ | |
| # Step 1: 检查ST是否已存在 | |
| existing_token = await self.db.get_token_by_st(st) | |
| if existing_token: | |
| raise ValueError(f"Token 已存在(邮箱: {existing_token.email})") | |
| # Step 2: 使用ST转换AT | |
| debug_logger.log_info(f"[ADD_TOKEN] Converting ST to AT...") | |
| try: | |
| result = await self.flow_client.st_to_at(st) | |
| at = result["access_token"] | |
| expires = result.get("expires") | |
| user_info = result.get("user", {}) | |
| email = user_info.get("email", "") | |
| name = user_info.get("name", email.split("@")[0] if email else "") | |
| # 解析过期时间 | |
| at_expires = None | |
| if expires: | |
| try: | |
| at_expires = datetime.fromisoformat(expires.replace('Z', '+00:00')) | |
| except: | |
| pass | |
| except Exception as e: | |
| raise ValueError(f"ST转AT失败: {str(e)}") | |
| # Step 3: 查询余额 | |
| try: | |
| credits_result = await self.flow_client.get_credits(at) | |
| credits = credits_result.get("credits", 0) | |
| user_paygate_tier = credits_result.get("userPaygateTier") | |
| except: | |
| credits = 0 | |
| user_paygate_tier = None | |
| # Step 4: 处理Project ID和名称 | |
| if project_id: | |
| # 用户提供了project_id,直接使用 | |
| debug_logger.log_info(f"[ADD_TOKEN] Using provided project_id: {project_id}") | |
| if not project_name: | |
| # 如果没有提供project_name,生成一个 | |
| now = datetime.now() | |
| project_name = now.strftime("%b %d - %H:%M") | |
| else: | |
| # 用户没有提供project_id,需要创建新项目 | |
| if not project_name: | |
| # 自动生成项目名称 | |
| now = datetime.now() | |
| project_name = now.strftime("%b %d - %H:%M") | |
| try: | |
| project_id = await self.flow_client.create_project(st, project_name) | |
| debug_logger.log_info(f"[ADD_TOKEN] Created new project: {project_name} (ID: {project_id})") | |
| except Exception as e: | |
| raise ValueError(f"创建项目失败: {str(e)}") | |
| # Step 5: 创建Token对象 | |
| token = Token( | |
| st=st, | |
| at=at, | |
| at_expires=at_expires, | |
| email=email, | |
| name=name, | |
| remark=remark, | |
| is_active=True, | |
| credits=credits, | |
| user_paygate_tier=user_paygate_tier, | |
| current_project_id=project_id, | |
| current_project_name=project_name, | |
| image_enabled=image_enabled, | |
| video_enabled=video_enabled, | |
| image_concurrency=image_concurrency, | |
| video_concurrency=video_concurrency | |
| ) | |
| # Step 6: 保存到数据库 | |
| token_id = await self.db.add_token(token) | |
| token.id = token_id | |
| # Step 7: 保存Project到数据库 | |
| project = Project( | |
| project_id=project_id, | |
| token_id=token_id, | |
| project_name=project_name, | |
| tool_name="PINHOLE" | |
| ) | |
| await self.db.add_project(project) | |
| debug_logger.log_info(f"[ADD_TOKEN] Token added successfully (ID: {token_id}, Email: {email})") | |
| return token | |
| async def update_token( | |
| self, | |
| token_id: int, | |
| st: Optional[str] = None, | |
| at: Optional[str] = None, | |
| at_expires: Optional[datetime] = None, | |
| project_id: Optional[str] = None, | |
| project_name: Optional[str] = None, | |
| remark: Optional[str] = None, | |
| image_enabled: Optional[bool] = None, | |
| video_enabled: Optional[bool] = None, | |
| image_concurrency: Optional[int] = None, | |
| video_concurrency: Optional[int] = None | |
| ): | |
| """Update token (支持修改project_id和project_name) | |
| 当用户编辑保存token时,如果token未过期,自动清空429禁用状态 | |
| """ | |
| update_fields = {} | |
| if st is not None: | |
| update_fields["st"] = st | |
| if at is not None: | |
| update_fields["at"] = at | |
| if at_expires is not None: | |
| update_fields["at_expires"] = at_expires | |
| if project_id is not None: | |
| update_fields["current_project_id"] = project_id | |
| if project_name is not None: | |
| update_fields["current_project_name"] = project_name | |
| if remark is not None: | |
| update_fields["remark"] = remark | |
| if image_enabled is not None: | |
| update_fields["image_enabled"] = image_enabled | |
| if video_enabled is not None: | |
| update_fields["video_enabled"] = video_enabled | |
| if image_concurrency is not None: | |
| update_fields["image_concurrency"] = image_concurrency | |
| if video_concurrency is not None: | |
| update_fields["video_concurrency"] = video_concurrency | |
| # 检查token是否因429被禁用,如果是且未过期,则清空429状态 | |
| token = await self.db.get_token(token_id) | |
| if token and token.ban_reason == "429_rate_limit": | |
| # 检查token是否过期 | |
| is_expired = False | |
| if token.at_expires: | |
| now = datetime.now(timezone.utc) | |
| if token.at_expires.tzinfo is None: | |
| at_expires_aware = token.at_expires.replace(tzinfo=timezone.utc) | |
| else: | |
| at_expires_aware = token.at_expires | |
| is_expired = at_expires_aware <= now | |
| # 如果未过期,清空429禁用状态 | |
| if not is_expired: | |
| debug_logger.log_info(f"[UPDATE_TOKEN] Token {token_id} 编辑保存,清空429禁用状态") | |
| update_fields["ban_reason"] = None | |
| update_fields["banned_at"] = None | |
| if update_fields: | |
| await self.db.update_token(token_id, **update_fields) | |
| # ========== AT自动刷新逻辑 (核心) ========== | |
| async def is_at_valid(self, token_id: int) -> bool: | |
| """检查AT是否有效,如果无效或即将过期则自动刷新 | |
| Returns: | |
| True if AT is valid or refreshed successfully | |
| False if AT cannot be refreshed | |
| """ | |
| token = await self.db.get_token(token_id) | |
| if not token: | |
| return False | |
| # 如果AT不存在,需要刷新 | |
| if not token.at: | |
| debug_logger.log_info(f"[AT_CHECK] Token {token_id}: AT不存在,需要刷新") | |
| return await self._refresh_at(token_id) | |
| # 如果没有过期时间,假设需要刷新 | |
| if not token.at_expires: | |
| debug_logger.log_info(f"[AT_CHECK] Token {token_id}: AT过期时间未知,尝试刷新") | |
| return await self._refresh_at(token_id) | |
| # 检查是否即将过期 (提前1小时刷新) | |
| now = datetime.now(timezone.utc) | |
| # 确保at_expires也是timezone-aware | |
| if token.at_expires.tzinfo is None: | |
| at_expires_aware = token.at_expires.replace(tzinfo=timezone.utc) | |
| else: | |
| at_expires_aware = token.at_expires | |
| time_until_expiry = at_expires_aware - now | |
| if time_until_expiry.total_seconds() < 3600: # 1 hour (3600 seconds) | |
| debug_logger.log_info(f"[AT_CHECK] Token {token_id}: AT即将过期 (剩余 {time_until_expiry.total_seconds():.0f} 秒),需要刷新") | |
| return await self._refresh_at(token_id) | |
| # AT有效 | |
| return True | |
| async def _refresh_at(self, token_id: int) -> bool: | |
| """内部方法: 刷新AT | |
| Returns: | |
| True if refresh successful, False otherwise | |
| """ | |
| async with self._lock: | |
| token = await self.db.get_token(token_id) | |
| if not token: | |
| return False | |
| try: | |
| debug_logger.log_info(f"[AT_REFRESH] Token {token_id}: 开始刷新AT...") | |
| # 使用ST转AT | |
| result = await self.flow_client.st_to_at(token.st) | |
| new_at = result["access_token"] | |
| expires = result.get("expires") | |
| # 解析过期时间 | |
| new_at_expires = None | |
| if expires: | |
| try: | |
| new_at_expires = datetime.fromisoformat(expires.replace('Z', '+00:00')) | |
| except: | |
| pass | |
| # 更新数据库 | |
| await self.db.update_token( | |
| token_id, | |
| at=new_at, | |
| at_expires=new_at_expires | |
| ) | |
| debug_logger.log_info(f"[AT_REFRESH] Token {token_id}: AT刷新成功") | |
| debug_logger.log_info(f" - 新过期时间: {new_at_expires}") | |
| # 同时刷新credits | |
| try: | |
| credits_result = await self.flow_client.get_credits(new_at) | |
| await self.db.update_token( | |
| token_id, | |
| credits=credits_result.get("credits", 0) | |
| ) | |
| except: | |
| pass | |
| return True | |
| except Exception as e: | |
| debug_logger.log_error(f"[AT_REFRESH] Token {token_id}: AT刷新失败 - {str(e)}") | |
| # 刷新失败,禁用Token | |
| await self.disable_token(token_id) | |
| return False | |
| async def ensure_project_exists(self, token_id: int) -> str: | |
| """确保Token有可用的Project | |
| Returns: | |
| project_id | |
| """ | |
| token = await self.db.get_token(token_id) | |
| if not token: | |
| raise ValueError("Token not found") | |
| # 如果已有project_id,直接返回 | |
| if token.current_project_id: | |
| return token.current_project_id | |
| # 创建新Project | |
| now = datetime.now() | |
| project_name = now.strftime("%b %d - %H:%M") | |
| try: | |
| project_id = await self.flow_client.create_project(token.st, project_name) | |
| debug_logger.log_info(f"[PROJECT] Created project for token {token_id}: {project_name}") | |
| # 更新Token | |
| await self.db.update_token( | |
| token_id, | |
| current_project_id=project_id, | |
| current_project_name=project_name | |
| ) | |
| # 保存Project到数据库 | |
| project = Project( | |
| project_id=project_id, | |
| token_id=token_id, | |
| project_name=project_name | |
| ) | |
| await self.db.add_project(project) | |
| return project_id | |
| except Exception as e: | |
| raise ValueError(f"Failed to create project: {str(e)}") | |
| # ========== Token使用统计 ========== | |
| async def record_usage(self, token_id: int, is_video: bool = False): | |
| """Record token usage""" | |
| await self.db.update_token(token_id, use_count=1, last_used_at=datetime.now()) | |
| if is_video: | |
| await self.db.increment_token_stats(token_id, "video") | |
| else: | |
| await self.db.increment_token_stats(token_id, "image") | |
| async def record_error(self, token_id: int): | |
| """Record token error and auto-disable if threshold reached""" | |
| await self.db.increment_token_stats(token_id, "error") | |
| # Check if should auto-disable token (based on consecutive errors) | |
| stats = await self.db.get_token_stats(token_id) | |
| admin_config = await self.db.get_admin_config() | |
| if stats and stats.consecutive_error_count >= admin_config.error_ban_threshold: | |
| debug_logger.log_warning( | |
| f"[TOKEN_BAN] Token {token_id} consecutive error count ({stats.consecutive_error_count}) " | |
| f"reached threshold ({admin_config.error_ban_threshold}), auto-disabling" | |
| ) | |
| await self.disable_token(token_id) | |
| async def record_success(self, token_id: int): | |
| """Record successful request (reset consecutive error count) | |
| This method resets error_count to 0, which is used for auto-disable threshold checking. | |
| Note: today_error_count and historical statistics are NOT reset. | |
| """ | |
| await self.db.reset_error_count(token_id) | |
| async def ban_token_for_429(self, token_id: int): | |
| """因429错误立即禁用token | |
| Args: | |
| token_id: Token ID | |
| """ | |
| debug_logger.log_warning(f"[429_BAN] 禁用Token {token_id} (原因: 429 Rate Limit)") | |
| await self.db.update_token( | |
| token_id, | |
| is_active=False, | |
| ban_reason="429_rate_limit", | |
| banned_at=datetime.now(timezone.utc) | |
| ) | |
| async def auto_unban_429_tokens(self): | |
| """自动解禁因429被禁用的token | |
| 规则: | |
| - 距离禁用时间12小时后自动解禁 | |
| - 仅解禁未过期的token | |
| - 仅解禁因429被禁用的token | |
| """ | |
| all_tokens = await self.db.get_all_tokens() | |
| now = datetime.now(timezone.utc) | |
| for token in all_tokens: | |
| # 跳过非429禁用的token | |
| if token.ban_reason != "429_rate_limit": | |
| continue | |
| # 跳过未禁用的token | |
| if token.is_active: | |
| continue | |
| # 跳过没有禁用时间的token | |
| if not token.banned_at: | |
| continue | |
| # 检查token是否已过期 | |
| if token.at_expires: | |
| # 确保时区一致 | |
| if token.at_expires.tzinfo is None: | |
| at_expires_aware = token.at_expires.replace(tzinfo=timezone.utc) | |
| else: | |
| at_expires_aware = token.at_expires | |
| # 如果已过期,跳过 | |
| if at_expires_aware <= now: | |
| debug_logger.log_info(f"[AUTO_UNBAN] Token {token.id} 已过期,跳过解禁") | |
| continue | |
| # 确保banned_at时区一致 | |
| if token.banned_at.tzinfo is None: | |
| banned_at_aware = token.banned_at.replace(tzinfo=timezone.utc) | |
| else: | |
| banned_at_aware = token.banned_at | |
| # 检查是否已过12小时 | |
| time_since_ban = now - banned_at_aware | |
| if time_since_ban.total_seconds() >= 12 * 3600: # 12小时 | |
| debug_logger.log_info( | |
| f"[AUTO_UNBAN] 解禁Token {token.id} (禁用时间: {banned_at_aware}, " | |
| f"已过 {time_since_ban.total_seconds() / 3600:.1f} 小时)" | |
| ) | |
| await self.db.update_token( | |
| token.id, | |
| is_active=True, | |
| ban_reason=None, | |
| banned_at=None | |
| ) | |
| # 重置错误计数 | |
| await self.db.reset_error_count(token.id) | |
| # ========== 余额刷新 ========== | |
| async def refresh_credits(self, token_id: int) -> int: | |
| """刷新Token余额 | |
| Returns: | |
| credits | |
| """ | |
| token = await self.db.get_token(token_id) | |
| if not token: | |
| return 0 | |
| # 确保AT有效 | |
| if not await self.is_at_valid(token_id): | |
| return 0 | |
| # 重新获取token (AT可能已刷新) | |
| token = await self.db.get_token(token_id) | |
| try: | |
| result = await self.flow_client.get_credits(token.at) | |
| credits = result.get("credits", 0) | |
| # 更新数据库 | |
| await self.db.update_token(token_id, credits=credits) | |
| return credits | |
| except Exception as e: | |
| debug_logger.log_error(f"Failed to refresh credits for token {token_id}: {str(e)}") | |
| return 0 | |