wulingling / src /key_manager.py
JXJBing's picture
Upload 36 files
9652ab4 verified
import time
import random
import uuid
import json
import os
import logging
import threading
from typing import Dict, List, Optional, Any, Union, Tuple, Callable
# 初始化日志
logger = logging.getLogger("sora-api.key_manager")
class KeyManager:
def __init__(self, storage_file: str = "api_keys.json"):
"""
初始化密钥管理器
Args:
storage_file: 密钥存储文件路径
"""
self.keys = [] # 密钥列表
self.storage_file = storage_file
self.usage_stats = {} # 使用统计
self._lock = threading.RLock() # 添加可重入锁以支持并发访问
self._working_keys = {} # 新增:记录正在工作中的密钥 {key_value: task_id}
self._load_keys()
def _load_keys(self) -> None:
"""从环境变量或文件加载密钥"""
keys_loaded = False
# 先尝试从环境变量加载
api_keys_str = os.getenv("API_KEYS", "")
if api_keys_str:
try:
env_data = json.loads(api_keys_str)
self._process_keys_data(env_data)
if len(self.keys) > 0:
logger.info(f"已从环境变量加载 {len(self.keys)} 个密钥")
keys_loaded = True
else:
logger.warning("环境变量API_KEYS存在但未包含有效密钥")
except json.JSONDecodeError as e:
logger.error(f"解析环境变量API keys失败: {str(e)}")
# 如果环境变量未设置、解析失败或未加载到密钥,从文件加载
if not keys_loaded:
try:
if os.path.exists(self.storage_file):
logger.info(f"尝试从文件加载密钥: {self.storage_file}")
with open(self.storage_file, 'r', encoding='utf-8') as f:
data = json.load(f)
keys_before = len(self.keys)
self._process_keys_data(data)
keys_loaded = len(self.keys) > keys_before
logger.info(f"已从文件加载 {len(self.keys) - keys_before} 个密钥")
else:
logger.warning(f"密钥文件不存在: {self.storage_file}")
except Exception as e:
logger.error(f"加载密钥失败: {str(e)}")
if len(self.keys) == 0:
logger.warning("未能从环境变量或文件加载任何密钥")
def _process_keys_data(self, data):
"""处理不同格式的密钥数据"""
# 处理不同的数据格式
if isinstance(data, list):
# 旧版格式:直接是密钥列表
raw_keys = data
self.keys = []
self.usage_stats = {}
# 为每个密钥创建完整的记录
for key_info in raw_keys:
if isinstance(key_info, dict):
key_value = key_info.get("key")
if not key_value:
logger.warning(f"忽略无效密钥配置: {key_info}")
continue
# 确保有ID
key_id = key_info.get("id") or str(uuid.uuid4())
# 构建完整的密钥记录
key_record = {
"id": key_id,
"name": key_info.get("name", ""),
"key": key_value,
"weight": key_info.get("weight", 1),
"max_rpm": key_info.get("max_rpm", 60),
"requests": 0,
"last_reset": time.time(),
"available": key_info.get("is_enabled", True),
"is_enabled": key_info.get("is_enabled", True),
"created_at": key_info.get("created_at", time.time()),
"last_used": key_info.get("last_used"),
"notes": key_info.get("notes")
}
self.keys.append(key_record)
# 初始化使用统计
self.usage_stats[key_id] = {
"total_requests": 0,
"successful_requests": 0,
"failed_requests": 0,
"daily_usage": {},
"average_response_time": 0
}
elif isinstance(key_info, str):
# 如果是字符串,直接作为密钥值
key_id = str(uuid.uuid4())
self.keys.append({
"id": key_id,
"name": "",
"key": key_info,
"weight": 1,
"max_rpm": 60,
"requests": 0,
"last_reset": time.time(),
"available": True,
"is_enabled": True,
"created_at": time.time(),
"last_used": None,
"notes": None
})
# 初始化使用统计
self.usage_stats[key_id] = {
"total_requests": 0,
"successful_requests": 0,
"failed_requests": 0,
"daily_usage": {},
"average_response_time": 0
}
else:
# 新版格式:包含keys和usage_stats的字典
self.keys = data.get('keys', [])
self.usage_stats = data.get('usage_stats', {})
def _save_keys(self) -> None:
"""保存密钥到文件"""
try:
with open(self.storage_file, 'w', encoding='utf-8') as f:
json.dump({
'keys': self.keys,
'usage_stats': self.usage_stats
}, f, ensure_ascii=False, indent=2)
# 同时更新Config中的API_KEYS
try:
from .config import Config
Config.API_KEYS = self.keys
except (ImportError, AttributeError):
logger.debug("无法更新Config中的API_KEYS")
except Exception as e:
logger.error(f"保存密钥失败: {str(e)}")
def add_key(self, key_value: str, name: str = "", weight: int = 1,
rate_limit: int = 60, is_enabled: bool = True, notes: str = None) -> Dict[str, Any]:
"""
添加密钥
Args:
key_value: 密钥值
name: 密钥名称
weight: 权重
rate_limit: 速率限制(每分钟请求数)
is_enabled: 是否启用
notes: 备注
Returns:
添加的密钥信息
"""
with self._lock: # 使用锁保护添加过程
# 检查密钥是否已存在
for key in self.keys:
if key.get("key") == key_value:
return key
key_id = str(uuid.uuid4())
new_key = {
"id": key_id,
"name": name,
"key": key_value,
"weight": weight,
"max_rpm": rate_limit,
"requests": 0,
"last_reset": time.time(),
"available": is_enabled,
"is_enabled": is_enabled,
"created_at": time.time(),
"last_used": None,
"notes": notes
}
self.keys.append(new_key)
# 初始化使用统计
self.usage_stats[key_id] = {
"total_requests": 0,
"successful_requests": 0,
"failed_requests": 0,
"daily_usage": {},
"average_response_time": 0
}
self._save_keys()
logger.info(f"已添加密钥: {name or key_id}")
return new_key
def get_all_keys(self) -> List[Dict[str, Any]]:
"""获取所有密钥信息(已隐藏完整密钥值)"""
with self._lock: # 使用锁保护读取过程
result = []
for key in self.keys:
key_copy = key.copy()
if "key" in key_copy:
# 只显示密钥前6位和后4位
full_key = key_copy["key"]
if len(full_key) > 10:
key_copy["key"] = full_key[:6] + "..." + full_key[-4:]
# 增加临时禁用信息的处理
if key_copy.get("temp_disabled_until"):
temp_disabled_until = key_copy["temp_disabled_until"]
# 确保temp_disabled_until是时间戳格式
if isinstance(temp_disabled_until, (int, float)):
# 转换为可读格式,但保留原始时间戳,让前端可以自行处理
disabled_until_date = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(temp_disabled_until))
key_copy["temp_disabled_until_formatted"] = disabled_until_date
key_copy["temp_disabled_remaining"] = int(temp_disabled_until - time.time())
result.append(key_copy)
return result
def get_key_by_id(self, key_id: str) -> Optional[Dict[str, Any]]:
"""根据ID获取密钥信息"""
with self._lock: # 使用锁保护读取过程
for key in self.keys:
if key.get("id") == key_id:
return key
return None
def update_key(self, key_id: str, **kwargs) -> Optional[Dict[str, Any]]:
"""
更新密钥信息
Args:
key_id: 密钥ID
**kwargs: 要更新的字段
Returns:
更新后的密钥信息,未找到则返回None
"""
with self._lock: # 使用锁保护更新过程
for key in self.keys:
if key.get("id") == key_id:
# 更新提供的字段
for field, value in kwargs.items():
if value is not None:
if field == "is_enabled":
key["available"] = value # 同步更新available字段
key[field] = value
self._save_keys()
logger.info(f"已更新密钥: {key.get('name') or key_id}")
return key
logger.warning(f"未找到密钥: {key_id}")
return None
def delete_key(self, key_id: str) -> bool:
"""
删除密钥
Args:
key_id: 密钥ID
Returns:
是否成功删除
"""
with self._lock: # 使用锁保护删除过程
original_length = len(self.keys)
self.keys = [key for key in self.keys if key.get("id") != key_id]
# 如果成功删除,保存密钥
if len(self.keys) < original_length:
self._save_keys()
return True
return False
def batch_import_keys(self, keys_data: List[Dict[str, Any]]) -> Dict[str, int]:
"""
批量导入密钥
Args:
keys_data: 密钥数据列表,每个元素为包含密钥信息的字典
Returns:
导入结果统计
"""
with self._lock: # 使用锁保护导入过程
imported_count = 0
skipped_count = 0
# 获取现有密钥值
existing_keys = {key.get("key") for key in self.keys}
for key_data in keys_data:
key_value = key_data.get("key")
if not key_value:
continue
# 检查密钥是否已存在
if key_value in existing_keys:
skipped_count += 1
continue
# 添加新密钥
key_id = str(uuid.uuid4())
new_key = {
"id": key_id,
"name": key_data.get("name", ""),
"key": key_value,
"weight": key_data.get("weight", 1),
"max_rpm": key_data.get("rate_limit", 60),
"requests": 0,
"last_reset": time.time(),
"available": key_data.get("enabled", True),
"is_enabled": key_data.get("enabled", True),
"created_at": time.time(),
"last_used": None,
"notes": key_data.get("notes")
}
self.keys.append(new_key)
existing_keys.add(key_value) # 添加到已存在集合中
# 初始化使用统计
self.usage_stats[key_id] = {
"total_requests": 0,
"successful_requests": 0,
"failed_requests": 0,
"daily_usage": {},
"average_response_time": 0
}
imported_count += 1
# 保存密钥
if imported_count > 0:
self._save_keys()
return {
"imported": imported_count,
"skipped": skipped_count
}
def get_key(self) -> Optional[str]:
"""获取下一个可用的密钥"""
with self._lock: # 使用锁保护整个获取密钥过程
if not self.keys:
logger.warning("没有可用的密钥")
return None
# 重置计数器(如果需要)
current_time = time.time()
temporary_disabled_updated = False
for key in self.keys:
# 检查是否有被临时禁用的密钥需要重新启用
if key.get("temp_disabled_until") and current_time > key.get("temp_disabled_until"):
key["is_enabled"] = True
key["available"] = True
key["temp_disabled_until"] = None
temporary_disabled_updated = True
logger.info(f"密钥 {key.get('name') or key.get('id')} 的临时禁用已解除")
if current_time - key["last_reset"] >= 60:
key["requests"] = 0
key["last_reset"] = current_time
if not key.get("temp_disabled_until"): # 只有未被临时禁用的密钥才会被重新激活
key["available"] = key.get("is_enabled", True)
# 如果有任何临时禁用的密钥被更新,保存变更
if temporary_disabled_updated:
self._save_keys()
# 筛选可用的密钥,排除工作中的密钥
available_keys = []
for k in self.keys:
key_value = k.get("key", "")
clean_key = key_value.replace("Bearer ", "") if key_value.startswith("Bearer ") else key_value
# 检查此密钥是否在工作中
is_working = clean_key in self._working_keys
if k.get("available", False) and not is_working:
available_keys.append(k)
if not available_keys:
logger.warning("没有可用的密钥(所有密钥都达到速率限制、被禁用或正在工作中)")
return None
# 根据权重选择密钥
weights = [k.get("weight", 1) for k in available_keys]
selected_idx = random.choices(range(len(available_keys)), weights=weights, k=1)[0]
selected_key = available_keys[selected_idx]
# 更新使用统计
selected_key["requests"] += 1
selected_key["last_used"] = current_time
# 检查是否达到速率限制
if selected_key["requests"] >= selected_key.get("max_rpm", 60):
selected_key["available"] = False
# 保存数据 - 并发环境下调整为每次都保存,避免状态不一致
# 原来是随机保存(10%的概率)
self._save_keys()
# 确保返回的密钥包含"Bearer "前缀
key_value = selected_key["key"]
if not key_value.startswith("Bearer "):
key_value = f"Bearer {key_value}"
return key_value
def record_request_result(self, key: str, success: bool, response_time: float = 0) -> None:
"""
记录请求结果
Args:
key: 密钥值
success: 请求是否成功
response_time: 响应时间(秒)
"""
if not key:
logger.warning("记录请求结果失败:密钥为空")
return
with self._lock: # 使用锁保护记录过程
# 去掉可能的Bearer前缀
key_for_search = key.replace("Bearer ", "") if key.startswith("Bearer ") else key
# 查找对应的密钥ID
key_id = None
key_info = None
for k in self.keys:
stored_key = k.get("key", "").replace("Bearer ", "") if k.get("key", "").startswith("Bearer ") else k.get("key", "")
if stored_key == key_for_search:
key_id = k.get("id")
key_info = k
break
if not key_id:
logger.warning(f"记录请求结果失败:未找到密钥 {key_for_search[:6]}...")
return
# 初始化usage_stats如果该密钥还没有统计数据
if key_id not in self.usage_stats:
self.usage_stats[key_id] = {
"total_requests": 0,
"successful_requests": 0,
"failed_requests": 0,
"daily_usage": {},
"average_response_time": 0
}
# 记录请求结果
stats = self.usage_stats[key_id]
stats["total_requests"] += 1
if success:
stats["successful_requests"] += 1
else:
stats["failed_requests"] += 1
# 记录响应时间
if response_time > 0:
if stats["average_response_time"] == 0:
stats["average_response_time"] = response_time
else:
# 使用加权平均
old_avg = stats["average_response_time"]
total = stats["total_requests"]
# 避免 total 为0或1时产生问题,尽管前面 total_requests 已经增加了
if total > 0:
stats["average_response_time"] = ((old_avg * (total - 1)) + response_time) / total
else: # 理论上不应该发生,因为 total_requests 已经增加了
stats["average_response_time"] = response_time
# 记录每日使用情况
today = time.strftime("%Y-%m-%d")
if today not in stats["daily_usage"]:
stats["daily_usage"][today] = {"successful": 0, "failed": 0} # 初始化每日统计
# 根据成功与否更新每日统计
if success:
stats["daily_usage"][today]["successful"] += 1
else:
stats["daily_usage"][today]["failed"] += 1
# 保留最近30天的数据
if len(stats["daily_usage"]) > 30:
# 获取所有日期并排序,然后删除最早的
sorted_dates = sorted(stats["daily_usage"].keys())
if sorted_dates: # 确保列表不为空
oldest_date = sorted_dates[0]
del stats["daily_usage"][oldest_date]
# 更新密钥的最后使用时间
if key_info and "last_used" in key_info:
key_info["last_used"] = time.time()
# 并发环境下每次都保存,确保统计准确性
self._save_keys()
def get_usage_stats(self) -> Dict[str, Any]:
"""获取使用统计信息"""
with self._lock: # 使用锁保护读取过程
total_keys = len(self.keys)
active_keys = sum(1 for k in self.keys if k.get("is_enabled", False))
available_keys = sum(1 for k in self.keys if k.get("available", False))
total_requests = sum(stats.get("total_requests", 0) for stats in self.usage_stats.values())
successful_requests = sum(stats.get("successful_requests", 0) for stats in self.usage_stats.values())
# 计算成功率
success_rate = (successful_requests / total_requests * 100) if total_requests > 0 else 0
# 计算每个密钥的平均响应时间
avg_response_times = [stats.get("average_response_time", 0) for stats in self.usage_stats.values() if stats.get("average_response_time", 0) > 0]
overall_avg_response_time = sum(avg_response_times) / len(avg_response_times) if avg_response_times else 0
# 获取过去7天的使用情况
past_7_days = {}
for key_id, stats in self.usage_stats.items():
daily_usage = stats.get("daily_usage", {})
for date, count_data in daily_usage.items():
if date not in past_7_days:
past_7_days[date] = {"successful": 0, "failed": 0}
# 正确处理字典类型的count_data
past_7_days[date]["successful"] += count_data.get("successful", 0)
past_7_days[date]["failed"] += count_data.get("failed", 0)
# 只保留最近7天
dates = sorted(past_7_days.keys(), reverse=True)[:7]
past_7_days = {date: past_7_days[date] for date in dates}
return {
"total_keys": total_keys,
"active_keys": active_keys,
"available_keys": available_keys,
"total_requests": total_requests,
"successful_requests": successful_requests,
"failed_requests": total_requests - successful_requests,
"success_rate": success_rate,
"average_response_time": overall_avg_response_time,
"past_7_days": past_7_days
}
def mark_key_as_working(self, key: str, task_id: str) -> None:
"""
将密钥标记为工作中状态
Args:
key: API密钥值(可能包含Bearer前缀)
task_id: 关联的任务ID
"""
with self._lock:
clean_key = key.replace("Bearer ", "") if key.startswith("Bearer ") else key
self._working_keys[clean_key] = task_id
logger.debug(f"密钥已标记为工作中,关联任务ID: {task_id}")
def release_key(self, key: str) -> None:
"""
释放工作中的密钥
Args:
key: API密钥值(可能包含Bearer前缀)
"""
with self._lock:
clean_key = key.replace("Bearer ", "") if key.startswith("Bearer ") else key
if clean_key in self._working_keys:
del self._working_keys[clean_key]
logger.debug(f"密钥已释放")
def is_key_working(self, key: str) -> bool:
"""
检查密钥是否正在工作中
Args:
key: API密钥值(可能包含Bearer前缀)
Returns:
bool: 是否在工作中
"""
with self._lock:
clean_key = key.replace("Bearer ", "") if key.startswith("Bearer ") else key
return clean_key in self._working_keys
def mark_key_invalid(self, key: str) -> Optional[str]:
"""
将指定的密钥标记为无效(临时禁用而不是永久禁用),并返回一个新的可用密钥
Args:
key: API密钥值(可能包含Bearer前缀)
Returns:
Optional[str]: 新的可用密钥,如果没有可用密钥则返回None
"""
# 调用临时禁用方法,设置24小时禁用时间
return self.mark_key_temp_disabled(key, hours=24.0)
def mark_key_temp_disabled(self, key: str, hours: float = 12.0) -> Optional[str]:
"""
将指定的密钥临时禁用指定小时数,并返回一个新的可用密钥
Args:
key: API密钥值(可能包含Bearer前缀)
hours: 禁用小时数
Returns:
Optional[str]: 新的可用密钥,如果没有可用密钥则返回None
"""
with self._lock: # 使用锁保护临时禁用过程
# 去掉可能的Bearer前缀
key_for_search = key.replace("Bearer ", "") if key.startswith("Bearer ") else key
# 检查是否是因为密钥在工作中导致的错误
if key_for_search in self._working_keys:
logger.warning(f"尝试禁用正在工作中的密钥(任务ID: {self._working_keys[key_for_search]}),跳过禁用操作")
# 获取一个新密钥返回,但不禁用当前密钥
new_key = self.get_key()
if new_key:
logger.info(f"已返回新密钥,但未禁用工作中的密钥")
return new_key
else:
logger.warning("没有可用的备用密钥")
return None
# 查找对应的密钥
key_found = False
disabled_key_id = None
for key_info in self.keys:
stored_key = key_info.get("key", "").replace("Bearer ", "") if key_info.get("key", "").startswith("Bearer ") else key_info.get("key", "")
if stored_key == key_for_search:
# 标记密钥为临时禁用
disabled_until = time.time() + (hours * 3600) # 当前时间加上禁用小时数
key_info["available"] = False
key_info["temp_disabled_until"] = disabled_until
key_info["notes"] = (key_info.get("notes") or "") + f"\n[自动] 在 {time.strftime('%Y-%m-%d %H:%M:%S')} 被临时禁用{hours}小时"
key_found = True
disabled_key_id = key_info.get("id")
logger.warning(f"密钥 {key_info.get('name') or key_info.get('id')} 被临时禁用{hours}小时")
break
if key_found:
# 保存更改
self._save_keys()
# 获取新的密钥,排除已禁用的
new_key = self.get_key()
if new_key:
logger.info(f"已自动切换到新的密钥")
return new_key
else:
logger.warning("没有可用的备用密钥")
return None
else:
logger.warning(f"未找到要临时禁用的密钥")
return None
def retry_request(self, original_key: str, request_func: Callable, max_retries: int = 1,
max_key_switches: int = 3) -> Tuple[bool, Any, str]:
"""
出错时自动重试请求,并在需要时切换密钥
Args:
original_key: 原始API密钥(可能包含Bearer前缀)
request_func: 执行请求的函数,接受一个参数(密钥)并返回(成功标志, 结果)
max_retries: 使用同一密钥的最大重试次数
max_key_switches: 最大密钥切换次数
Returns:
Tuple[bool, Any, str]: (是否成功, 请求结果, 使用的密钥)
"""
current_key = original_key
current_key_switches = 0
# 首先用原始密钥尝试
for attempt in range(max_retries + 1): # +1是因为第一次不算重试
try:
success, result = request_func(current_key)
# 成功的请求不应该导致密钥被禁用
if success:
# 记录请求成功,避免不必要的密钥禁用
with self._lock:
self.record_request_result(current_key, True)
return True, result, current_key
logger.warning(f"请求失败(尝试 {attempt+1}/{max_retries+1}): {result}")
except Exception as e:
logger.error(f"请求异常(尝试 {attempt+1}/{max_retries+1}): {str(e)}")
# 如果这不是最后一次尝试,等待一秒后重试
if attempt < max_retries:
time.sleep(1)
# 如果原始密钥的所有重试都失败,尝试切换密钥
tried_keys = set([current_key.replace("Bearer ", "") if current_key.startswith("Bearer ") else current_key])
while current_key_switches < max_key_switches:
# 获取新的密钥
with self._lock:
new_key = self.get_key()
if not new_key:
logger.warning("没有更多可用的密钥")
break
# 确保不使用已经尝试过的密钥
clean_new_key = new_key.replace("Bearer ", "") if new_key.startswith("Bearer ") else new_key
if clean_new_key in tried_keys:
continue
tried_keys.add(clean_new_key)
current_key = new_key
current_key_switches += 1
logger.info(f"切换到新密钥 (切换 {current_key_switches}/{max_key_switches})")
# 用新密钥尝试
for attempt in range(max_retries + 1):
try:
success, result = request_func(current_key)
if success:
# 记录请求成功
with self._lock:
self.record_request_result(current_key, True)
return True, result, current_key
logger.warning(f"使用新密钥请求失败(尝试 {attempt+1}/{max_retries+1}): {result}")
except Exception as e:
logger.error(f"使用新密钥请求异常(尝试 {attempt+1}/{max_retries+1}): {str(e)}")
# 如果这不是最后一次尝试,等待一秒后重试
if attempt < max_retries:
time.sleep(1)
# 所有尝试都失败,临时禁用原始密钥
# 但是在并发环境下,这可能是因为网络或服务问题,而非密钥问题
# 增加额外检查以减少不必要的密钥禁用
should_disable = True
# 在临时禁用前,确认是否是密钥问题而非服务或网络问题
# 此处可以添加额外逻辑来判断是否应该禁用密钥
if should_disable:
logger.error(f"所有重试和密钥切换尝试都失败,临时禁用原始密钥")
with self._lock:
self.mark_key_temp_disabled(original_key, hours=6.0) # 减少禁用时间,避免资源浪费
else:
logger.warning(f"所有重试和密钥切换尝试都失败,但可能是服务问题而非密钥问题,不禁用密钥")
# 返回最后一次尝试的结果
return False, result, current_key
# 创建全局密钥管理器实例
storage_file = os.getenv("KEYS_STORAGE_FILE", "api_keys.json")
# 如果提供了绝对路径则直接使用,否则使用相对路径
if not os.path.isabs(storage_file):
base_dir = os.getenv("BASE_DIR", os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
storage_file = os.path.join(base_dir, storage_file)
key_manager = KeyManager(storage_file=storage_file)
logger.info(f"初始化全局密钥管理器,存储文件: {storage_file}")