| import logging |
| import json |
| import os |
| import time |
| import tiktoken |
| from datetime import datetime |
| from typing import Dict, Any, Optional, Tuple |
|
|
| |
def setup_logging():
    """Configure the root logging system from environment variables.

    Reads LOG_PATH (default /tmp/2api.log), LOG_LEVEL (default INFO) and
    LOG_FORMAT from the environment, installs a console handler plus a
    UTF-8 file handler on the root logger, and returns the '2api' logger.
    """
    path = os.environ.get("LOG_PATH", "/tmp/2api.log")
    level_name = os.environ.get("LOG_LEVEL", "INFO").upper()
    # Unknown level names fall back to INFO rather than raising.
    level = getattr(logging, level_name, logging.INFO)
    fmt = os.environ.get("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    handlers = [
        logging.StreamHandler(),
        logging.FileHandler(path, encoding='utf-8'),
    ]
    logging.basicConfig(level=level, format=fmt, handlers=handlers)
    return logging.getLogger('2api')
|
|
| logger = setup_logging() |
|
|
def load_config():
    """Load settings from config.json if it exists, else return an empty dict.

    The path comes from the CONFIG_FILE_PATH environment variable, defaulting
    to a config.json next to this module. Unreadable or malformed files are
    logged and treated as absent (empty configuration).
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.environ.get("CONFIG_FILE_PATH", os.path.join(base_dir, 'config.json'))

    if not os.path.exists(config_path):
        return {}

    try:
        with open(config_path, 'r', encoding='utf-8') as fh:
            loaded = json.load(fh)
    except (json.JSONDecodeError, IOError) as e:
        logger.error(f"加载配置文件失败: {e}")
        return {}

    logger.info(f"已从 {config_path} 加载配置")
    return loaded
|
|
def mask_email(email: str) -> str:
    """Mask the middle of an email's local part to protect privacy.

    Examples: "john.doe@example.com" -> "j******e@example.com",
    "ab@x.com" -> "a*@x.com". Inputs that are empty, lack an '@', or
    have an empty local part are reported as invalid.

    Args:
        email: the address to mask.

    Returns:
        str: the masked address, or the invalid-email marker string.
    """
    if not email or '@' not in email:
        return "无效邮箱"

    # maxsplit=1: only the first '@' separates local part from domain, so
    # addresses containing a second '@' no longer lose their tail.
    username, domain = email.split('@', 1)

    # Bug fix: "@example.com" previously passed the '@' check and then
    # raised IndexError on username[0]; treat an empty local part as invalid.
    if not username:
        return "无效邮箱"

    if len(username) <= 3:
        # Short names: keep only the first character visible.
        masked_username = username[0] + '*' * (len(username) - 1)
    else:
        # Longer names: keep first and last characters visible.
        masked_username = username[0] + '*' * (len(username) - 2) + username[-1]

    return f"{masked_username}@{domain}"
|
|
def generate_request_id() -> str:
    """Return a unique OpenAI-style chat completion request id.

    The id is "chatcmpl-" followed by 32 hex characters drawn from
    16 cryptographically random bytes.
    """
    random_part = os.urandom(16).hex()
    return "chatcmpl-" + random_part
|
|
def count_tokens(text: str, model: str = "gpt-3.5-turbo") -> int:
    """Count the number of tokens in a piece of text.

    Args:
        text: text to tokenize; None becomes "" and non-strings are
            coerced with str().
        model: model name used to pick the tiktoken encoding
            (default "gpt-3.5-turbo").

    Returns:
        int: the token count, or len(text) // 4 as a rough estimate if
        tokenization fails for any reason.
    """
    if text is None:
        text = ""
    if not isinstance(text, str):
        text = str(text)

    try:
        if "gpt-4" in model:
            encoding = tiktoken.encoding_for_model("gpt-4")
        elif "gpt-3.5" in model:
            encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
        else:
            # claude and all other/unknown models use cl100k_base.
            encoding = tiktoken.get_encoding("cl100k_base")
        return len(encoding.encode(text))
    except Exception as e:
        logger.error(f"计算token数量时出错: {e}")
        # Fallback heuristic: ~4 characters per token.
        return len(text) // 4
|
|
def count_message_tokens(messages: list, model: str = "gpt-3.5-turbo") -> Tuple[int, int, int]:
    """Count tokens for an OpenAI-format message list.

    Args:
        messages: list of {"role": ..., "content": ...} dicts; None or a
            non-list is treated as empty (the latter with a warning).
        model: model name forwarded to count_tokens
            (default "gpt-3.5-turbo").

    Returns:
        Tuple[int, int, int]: (prompt tokens, completion tokens, total),
        or (0, 0, 0) if counting fails.
    """
    if messages is None:
        messages = []
    elif not isinstance(messages, list):
        logger.warning(f"count_message_tokens 收到非列表类型的消息: {type(messages)}")
        messages = []

    prompt_tokens = 0
    completion_tokens = 0

    try:
        for msg in messages:
            # Skip anything that is not a message dict.
            if not isinstance(msg, dict):
                logger.warning(f"跳过非字典类型的消息: {type(msg)}")
                continue

            role = msg.get('role', '')
            content = msg.get('content', '')
            if not (role and content):
                # Messages with a missing/empty role or content are ignored.
                continue

            # Per-message overhead: 4 tokens of framing + 1 for the role.
            prompt_tokens += 4 + 1
            prompt_tokens += count_tokens(content, model)

            # NOTE(review): assistant content is counted toward BOTH the
            # prompt and the completion totals here, so it contributes
            # twice to total_tokens — looks intentional as a rough
            # estimate, but worth confirming against callers.
            if role == 'assistant':
                completion_tokens += count_tokens(content, model)

        # Reply-priming overhead.
        prompt_tokens += 2

        total_tokens = prompt_tokens + completion_tokens
        return prompt_tokens, completion_tokens, total_tokens
    except Exception as e:
        logger.error(f"计算消息token数量时出错: {e}")
        return 0, 0, 0