import asyncio
import base64
import json
import logging
import mimetypes
import os
import re
import time
import traceback
from typing import Any, AsyncGenerator, Dict, List, Optional, Union

import aiohttp
import uvicorn
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
|
|
| |
# Configure root logging once at import time; every module logger inherits it.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


# FastAPI application exposing an OpenAI-compatible surface backed by Replicate.
app = FastAPI(
    title="Replicate API Proxy for LobeChat",
    description="A proxy service to forward Replicate API requests in OpenAI-compatible format",
    version="1.0.0"
)
|
|
| |
# Allow any origin so browser-based clients (e.g. LobeChat) can call the proxy.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers for credentialed requests; tighten for production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
| |
# Replicate API token is required for every upstream call; log loudly if it
# is missing but keep the module importable (the endpoint re-checks it).
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
if not REPLICATE_API_TOKEN:
    logger.error("REPLICATE_API_TOKEN not found in environment variables")


# imgbb is used as a temporary image host for models whose image_format is
# "url" (uploads expire after 300 seconds, see upload_image_to_imgbb).
# SECURITY NOTE(review): this API key is hard-coded in source; it should be
# loaded from an environment variable like REPLICATE_API_TOKEN above.
IMGBB_API_KEY = "78f0c4360135e80c46b24b44e1e20a20"
IMGBB_API_URL = "https://api.imgbb.com/1/upload"


# Upstream Replicate REST base URL and the model used when none is specified.
REPLICATE_BASE_URL = "https://api.replicate.com/v1"
DEFAULT_MODEL = "anthropic/claude-3.5-sonnet"
|
|
| |
# File extensions treated as readable text: their decoded contents are
# inlined into the prompt (see format_files_for_prompt).
SUPPORTED_TEXT_EXTENSIONS = {
    '.txt', '.md', '.py', '.js', '.ts', '.html', '.htm', '.css', '.json',
    '.xml', '.yaml', '.yml', '.sh', '.bash', '.zsh', '.fish', '.ps1',
    '.java', '.c', '.cpp', '.cc', '.cxx', '.h', '.hpp', '.cs', '.php',
    '.rb', '.go', '.rs', '.swift', '.kt', '.scala', '.r', '.sql',
    '.dockerfile', '.gitignore', '.gitattributes', '.env', '.ini', '.conf',
    '.log', '.csv', '.tsv', '.properties', '.toml', '.lock'
}


# File extensions treated as images (forwarded to vision-capable models).
SUPPORTED_IMAGE_EXTENSIONS = {
    '.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp', '.svg'
}
|
|
| |
# Per-model capability and parameter table.
#   min_max_tokens       - lowest max_tokens the model accepts; smaller client
#                          values are bumped up to this floor
#   default_max_tokens   - applied when the client omits max_tokens and the
#                          model requires one (has_max_tokens_limit=True)
#   has_max_tokens_limit - False means max_tokens is omitted entirely when the
#                          client does not send one
#   supports_vision      - image input is expected to work
#   supports_files       - text-file attachments are expected to work
#   image_format         - "url": image must be hosted (uploaded to imgbb);
#                          "data_url": inline base64 data URL is accepted
MODEL_CONFIGS = {
    "anthropic/claude-4-sonnet": {
        "min_max_tokens": 1024,
        "default_max_tokens": 8192,
        "has_max_tokens_limit": True,
        "supports_vision": True,
        "supports_files": True,
        "image_format": "url"
    },
    "anthropic/claude-3.5-sonnet": {
        "min_max_tokens": 1,
        "default_max_tokens": 8192,
        "has_max_tokens_limit": False,
        "supports_vision": True,
        "supports_files": True,
        "image_format": "data_url"
    },
    "anthropic/claude-3-sonnet": {
        "min_max_tokens": 1,
        "default_max_tokens": 4096,
        "has_max_tokens_limit": False,
        "supports_vision": True,
        "supports_files": True,
        "image_format": "data_url"
    },
    "anthropic/claude-3.5-haiku": {
        "min_max_tokens": 1,
        "default_max_tokens": 4096,
        "has_max_tokens_limit": False,
        "supports_vision": True,
        "supports_files": True,
        "image_format": "data_url"
    },
    "anthropic/claude-3-haiku": {
        "min_max_tokens": 1,
        "default_max_tokens": 4096,
        "has_max_tokens_limit": False,
        "supports_vision": True,
        "supports_files": True,
        "image_format": "data_url"
    },
    "google/gemini-2.5-pro": {
        "min_max_tokens": 1,
        "default_max_tokens": 8192,
        "has_max_tokens_limit": False,
        "supports_vision": True,
        "supports_files": True,
        "image_format": "data_url"
    }
}
|
|
| |
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log the failure and return a JSON 500 payload."""
    logger.error(f"Global exception: {str(exc)}")
    logger.error(f"Traceback: {traceback.format_exc()}")
    error_payload = {
        "error": {
            "message": f"Internal server error: {str(exc)}",
            "type": "internal_error"
        }
    }
    return JSONResponse(status_code=500, content=error_payload)
|
|
def get_file_extension(filename: str) -> str:
    """Return the lowercase file extension of *filename*, including the dot."""
    _, extension = os.path.splitext(filename.lower())
    return extension
|
|
def decode_base64_file(data_url: str) -> tuple[Optional[str], Optional[str], Optional[str]]:
    """
    Decode a base64 data URL into its component parts.

    Returns:
        (mime_type, filename, content) where content is decoded text for
        text files, the still-encoded base64 string for binary files, and
        None when decoding fails. All three are None when the input is not
        a data URL.  (Annotation fixed: the previous tuple[str, str, str]
        did not reflect the None returns.)
    """
    try:
        if not data_url.startswith("data:"):
            return None, None, None

        # A data URL looks like: data:<mime>[;name=<file>][;base64],<payload>
        header, base64_content = data_url.split(",", 1)
        header_parts = header.split(";")

        mime_type = header_parts[0].replace("data:", "")
        filename = "unknown_file"

        for part in header_parts:
            if part.startswith("name="):
                filename = part.replace("name=", "")
                break

        try:
            decoded_bytes = base64.b64decode(base64_content)

            # Text files are decoded to str; everything else stays base64.
            if mime_type.startswith("text/") or any(filename.lower().endswith(ext) for ext in SUPPORTED_TEXT_EXTENSIONS):
                try:
                    content = decoded_bytes.decode('utf-8')
                    return mime_type, filename, content
                except UnicodeDecodeError:
                    try:
                        # latin-1 maps every byte, so this fallback cannot
                        # actually raise; the inner except is kept defensively.
                        content = decoded_bytes.decode('latin-1')
                        return mime_type, filename, content
                    except UnicodeDecodeError:
                        # BUG FIX: log the actual filename instead of the
                        # literal "(unknown)" placeholder.
                        logger.warning(f"Failed to decode text file {filename}")
                        return mime_type, filename, None
            else:
                # Binary payloads are passed through still base64-encoded.
                return mime_type, filename, base64_content

        except Exception as e:
            logger.error(f"Failed to decode base64 content: {e}")
            return mime_type, filename, None

    except Exception as e:
        logger.error(f"Failed to parse data URL: {e}")
        return None, None, None
|
|
async def download_image_from_url(url: str) -> Optional[str]:
    """
    Download an image over HTTP and return it as a base64 data URL.

    Returns None when the download fails or times out (annotation fixed:
    the previous plain `str` did not reflect the None returns).
    """
    # URL-suffix fallback used when the server sends no usable Content-Type.
    suffix_to_mime = {
        ('.jpg', '.jpeg'): 'image/jpeg',
        ('.png',): 'image/png',
        ('.gif',): 'image/gif',
        ('.webp',): 'image/webp',
    }
    try:
        logger.info(f"Downloading image from URL: {url}")

        async with aiohttp.ClientSession() as session:
            async with session.get(url, timeout=30) as response:
                if response.status != 200:
                    logger.error(f"Failed to download image: HTTP {response.status}")
                    return None

                image_bytes = await response.read()

                # Prefer the server-provided Content-Type; otherwise guess
                # from the URL suffix, defaulting to JPEG.
                content_type = response.headers.get('content-type', '')
                if not content_type.startswith('image/'):
                    lowered = url.lower()
                    content_type = 'image/jpeg'
                    for suffixes, mime in suffix_to_mime.items():
                        if lowered.endswith(suffixes):
                            content_type = mime
                            break

                base64_data = base64.b64encode(image_bytes).decode('utf-8')
                data_url = f"data:{content_type};base64,{base64_data}"

                logger.info(f"Successfully downloaded image, size: {len(image_bytes)} bytes, base64 size: {len(base64_data)} chars")
                return data_url

    except asyncio.TimeoutError:
        logger.error(f"Timeout downloading image from {url}")
        return None
    except Exception as e:
        logger.error(f"Error downloading image from {url}: {e}")
        return None
|
|
async def upload_image_to_imgbb(base64_data: str) -> Optional[str]:
    """
    Upload a base64-encoded image to imgbb and return its hosted URL.

    Accepts either a bare base64 string or a full data URL. The upload is
    given a 300-second expiration (temporary hosting only). Returns None
    on any failure (annotation fixed: previously declared plain `str`).
    """
    try:
        # Strip a data-URL prefix if present; imgbb wants bare base64.
        if base64_data.startswith("data:"):
            base64_content = base64_data.split(",")[1]
        else:
            base64_content = base64_data

        data = {
            'key': IMGBB_API_KEY,
            'image': base64_content,
            'expiration': 300  # seconds; image auto-deletes afterwards
        }

        logger.info(f"Uploading image to imgbb, size: {len(base64_content)} chars")

        async with aiohttp.ClientSession() as session:
            async with session.post(IMGBB_API_URL, data=data, timeout=30) as response:
                if response.status == 200:
                    result = await response.json()
                    if result.get('success'):
                        image_url = result['data']['url']
                        logger.info(f"Image uploaded successfully: {image_url}")
                        return image_url
                    else:
                        logger.error(f"imgbb upload failed: {result}")
                        return None
                else:
                    error_text = await response.text()
                    logger.error(f"imgbb upload error: {response.status} - {error_text}")
                    return None

    except asyncio.TimeoutError:
        logger.error("Timeout uploading image to imgbb")
        return None
    except Exception as e:
        logger.error(f"Failed to upload image to imgbb: {e}")
        return None
|
|
async def format_image_for_model(base64_data: str, model_config: Dict[str, Any]) -> str:
    """
    Convert image data into the representation the target model expects.

    Models with image_format == "url" get a hosted URL (falling back to a
    data URL if hosting fails); "data_url" models get an inline data URL;
    unknown formats pass the data through untouched.
    """
    target_format = model_config.get("image_format", "data_url")

    if target_format == "data_url":
        return format_image_as_data_url(base64_data)

    if target_format == "url":
        hosted_url = await upload_image_to_imgbb(base64_data)
        if hosted_url:
            return hosted_url
        logger.error("Failed to upload image, falling back to data URL")
        return format_image_as_data_url(base64_data)

    # Unrecognized format: return the input unchanged.
    return base64_data
|
|
def format_image_as_data_url(base64_data: str) -> str:
    """
    Wrap bare base64 image bytes in a data URL, sniffing the MIME type from
    the decoded magic bytes (PNG/GIF/WebP/JPEG; JPEG is the fallback).
    """
    # Already a data URL: nothing to do.
    if base64_data.startswith("data:"):
        return base64_data

    mime = "image/jpeg"  # default when the signature is unrecognized
    try:
        # Decoding a short prefix is enough to inspect the file signature.
        head = base64.b64decode(base64_data[:100])
        if head.startswith(b'\x89PNG\r\n\x1a\n'):
            mime = "image/png"
        elif head.startswith((b'GIF87a', b'GIF89a')):
            mime = "image/gif"
        elif head.startswith(b'RIFF') and b'WEBP' in head[:20]:
            mime = "image/webp"
        elif head.startswith(b'\xff\xd8\xff'):
            mime = "image/jpeg"
    except Exception as e:
        logger.warning(f"Failed to detect image format: {e}, using JPEG as default")

    return f"data:{mime};base64,{base64_data}"
|
|
def extract_images_from_context(content: str) -> List[str]:
    """
    Extract http(s) image URLs embedded in text as <image url="..."></image>
    tags and return them in document order.

    FIX: `re` was previously imported inside this function on every call;
    it is now a normal module-level import.
    """
    images = []
    try:
        pattern = r'<image[^>]+url="([^"]+)"[^>]*></image>'
        matches = re.findall(pattern, content)
        for url in matches:
            # Only absolute http(s) URLs are usable downstream.
            if url.startswith('http'):
                images.append(url)
                logger.info(f"Found image URL in context: {url}")
    except Exception as e:
        logger.error(f"Error extracting images from context: {e}")

    return images
|
|
def extract_content_from_message(message: Dict[str, Any]) -> tuple[str, List[str], List[Dict[str, str]]]:
    """
    Extract text, images and file attachments from one chat message.

    Returns:
        (text_content, image_list, file_list) where images are http(s) URLs
        or base64 data URLs, and each file dict carries "filename",
        "content" and "mime_type" keys.
    """
    content = message.get("content", "")
    images = []
    files = []

    if isinstance(content, str):
        # Plain-string content: scan it for embedded <image url="..."> tags.
        context_images = extract_images_from_context(content)
        if context_images:
            images.extend(context_images)
        return content, images, files
    elif isinstance(content, list):
        # OpenAI multi-part content: a list of typed parts.
        text_parts = []

        for item in content:
            if isinstance(item, dict):
                item_type = item.get("type", "")

                if item_type == "text":
                    text_content = item.get("text", "")
                    text_parts.append(text_content)
                    # Text parts may also carry embedded image tags.
                    context_images = extract_images_from_context(text_content)
                    if context_images:
                        images.extend(context_images)

                elif item_type == "image_url":
                    image_url = item.get("image_url", {})
                    url = image_url.get("url", "")

                    if url.startswith("data:image/"):
                        # Inline base64 image: keep the full data URL.
                        try:
                            if ";base64," in url:
                                base64_data = url.split(";base64,")[1]
                                images.append(url)
                                logger.info(f"Found base64 image, size: {len(base64_data)} chars")
                            else:
                                logger.warning(f"Image URL format not supported: {url[:100]}...")
                        except Exception as e:
                            logger.error(f"Error processing image: {e}")
                    elif url.startswith("http"):
                        # External image; downloaded later in the pipeline.
                        images.append(url)
                        logger.info(f"Found external image URL: {url}")
                    else:
                        logger.warning(f"Unsupported image URL format: {url}")

                elif item_type == "file" or (item_type == "image_url" and not item.get("image_url", {}).get("url", "").startswith("data:image/")):
                    # NOTE(review): the `image_url` half of this condition is
                    # unreachable — the elif branch above already consumes every
                    # item_type == "image_url" part, so non-image data: payloads
                    # sent as image_url never reach this file handling. Only
                    # item_type == "file" actually gets here.
                    file_url = item.get("image_url", {}).get("url", "") if item_type == "image_url" else item.get("file_url", {}).get("url", "")

                    if file_url.startswith("data:"):
                        mime_type, filename, file_content = decode_base64_file(file_url)

                        if file_content is not None:
                            file_ext = get_file_extension(filename)

                            if file_ext in SUPPORTED_IMAGE_EXTENSIONS and mime_type.startswith("image/"):
                                # Image shipped as a file attachment.
                                images.append(file_url)
                                logger.info(f"Found image file: (unknown)")
                            elif file_ext in SUPPORTED_TEXT_EXTENSIONS or mime_type.startswith("text/"):
                                # Readable text file: keep the decoded content.
                                files.append({
                                    "filename": filename,
                                    "content": file_content,
                                    "mime_type": mime_type
                                })
                                logger.info(f"Found text file: (unknown), size: {len(file_content)} chars")
                            else:
                                logger.warning(f"Unsupported file type: (unknown) ({mime_type})")

            elif isinstance(item, str):
                text_parts.append(item)
                context_images = extract_images_from_context(item)
                if context_images:
                    images.extend(context_images)

        return " ".join(text_parts), images, files

    # Anything else (numbers, None, ...) is coerced to text.
    return str(content), images, files
|
|
def format_files_for_prompt(files: List[Dict[str, str]]) -> str:
    """
    Render uploaded text files into a prompt section the model can read.

    Each file is wrapped in delimiter lines carrying its name and MIME
    type; contents longer than 10,000 characters are truncated with a
    Chinese truncation notice.
    """
    if not files:
        return ""

    file_sections = []
    for file_data in files:
        filename = file_data["filename"]
        content = file_data["content"]
        mime_type = file_data.get("mime_type", "text/plain")

        # Cap very large files so a single upload cannot blow up the prompt.
        max_length = 10000
        if len(content) > max_length:
            content = content[:max_length] + "\n\n[文件内容已截断,显示前 10000 字符]"

        # BUG FIX: the header previously printed the literal "(unknown)"
        # instead of the actual filename.
        file_section = f"""

--- 文件: {filename} ({mime_type}) ---
{content}
--- 文件结束 ---
"""
        file_sections.append(file_section)

    return "\n".join(file_sections)
|
|
async def transform_openai_to_replicate(openai_request: Dict[str, Any], model_override: str = None) -> tuple[Dict[str, Any], str]:
    """
    Convert an OpenAI-format chat request into a Replicate prediction payload.

    Args:
        openai_request: parsed OpenAI chat-completions request body.
        model_override: optional model name taking precedence over the request's.

    Returns:
        (replicate_request, resolved_model_slug) — note the annotation was
        previously `Dict[str, Any]` but the function returns a 2-tuple.

    Raises:
        HTTPException(400) when transformation fails.
    """
    try:
        messages = openai_request.get("messages", [])

        # Partition the conversation into system prompt vs. user/assistant
        # turns, collecting images and file attachments along the way.
        system_prompt = None
        user_messages = []
        has_images = False
        has_files = False
        all_files = []
        primary_image = None  # only one image is forwarded to Replicate

        for message in messages:
            if message.get("role") == "system":
                system_prompt = message.get("content", "")
            elif message.get("role") in ["user", "assistant"]:
                text_content, image_list, file_list = extract_content_from_message(message)

                msg_data = {
                    "role": message.get("role"),
                    "content": text_content,
                    "images": image_list,
                    "files": file_list
                }
                user_messages.append(msg_data)

                if image_list:
                    has_images = True
                    # The first image of the latest user message wins.
                    if message.get("role") == "user":
                        primary_image = image_list[0]

                if file_list:
                    has_files = True
                    all_files.extend(file_list)

        # Resolve the model, mapping short aliases to full Replicate slugs.
        model = model_override or openai_request.get("model", DEFAULT_MODEL)

        model_mapping = {
            "claude-4-sonnet": "anthropic/claude-4-sonnet",
            "claude-3.5-sonnet": "anthropic/claude-3.5-sonnet",
            "claude-3-sonnet": "anthropic/claude-3-sonnet",
            "claude-3.5-haiku": "anthropic/claude-3.5-haiku",
            "claude-3-haiku": "anthropic/claude-3-haiku",
            "gemini-2.5-pro": "google/gemini-2.5-pro",
        }

        if model in model_mapping:
            model = model_mapping[model]
        elif not model.startswith(("anthropic/", "google/")):
            # Unknown models fall back to a safe default.
            model = "anthropic/claude-3.5-sonnet"

        model_config = MODEL_CONFIGS.get(model, MODEL_CONFIGS["anthropic/claude-3.5-sonnet"])

        # Warn (but do not fail) when attachments exceed model capabilities.
        if has_images and not model_config.get("supports_vision", False):
            logger.warning(f"Model {model} may not support vision")
        if has_files and not model_config.get("supports_files", False):
            logger.warning(f"Model {model} may not support file processing")

        # Download external images and convert to the format the model wants.
        formatted_image = None
        if has_images and primary_image:
            logger.info(f"Processing image for model {model} with format {model_config.get('image_format')}")

            if primary_image.startswith("http"):
                logger.info(f"Downloading external image: {primary_image}")
                downloaded_image = await download_image_from_url(primary_image)
                if downloaded_image:
                    primary_image = downloaded_image
                    logger.info("External image downloaded and converted to base64")
                else:
                    logger.error("Failed to download external image")
                    primary_image = None

            if primary_image:
                formatted_image = await format_image_for_model(primary_image, model_config)

                if not formatted_image:
                    logger.error("Failed to format image for model")
                    # NOTE(review): this 500 is caught by the generic handler
                    # at the bottom and re-raised as a 400 — confirm intent.
                    raise HTTPException(status_code=500, detail="Failed to process image")

        replicate_input = {}

        # Build a single Human/Assistant-style prompt string.
        prompt_parts = []

        # Uploaded files are inlined at the top of the prompt.
        if has_files:
            files_section = format_files_for_prompt(all_files)
            if files_section:
                prompt_parts.append("以下是用户上传的文件内容:")
                prompt_parts.append(files_section)
                prompt_parts.append("请根据上述文件内容回答用户的问题。")

        for i, msg in enumerate(user_messages):
            role = msg["role"]
            content = msg["content"]

            if role == "user":
                prompt_parts.append(f"Human: {content}")
            elif role == "assistant":
                prompt_parts.append(f"Assistant: {content}")

        # Ensure the prompt ends with an open Assistant turn.
        prompt = "\n\n".join(prompt_parts)
        if not prompt.endswith("\n\nAssistant:"):
            prompt += "\n\nAssistant:"

        replicate_input["prompt"] = prompt

        if formatted_image:
            replicate_input["image"] = formatted_image
            if formatted_image.startswith("http"):
                logger.info(f"Added image URL to request for model {model}: {formatted_image}")
            else:
                logger.info(f"Added image data to request for model {model}: {formatted_image[:100]}...")

        if system_prompt:
            replicate_input["system_prompt"] = system_prompt

        # max_tokens: honor the client's value clamped to the model floor;
        # otherwise apply the model default only when the model requires one.
        client_max_tokens = openai_request.get("max_tokens")

        if client_max_tokens is not None:
            max_tokens = client_max_tokens
            if max_tokens < model_config["min_max_tokens"]:
                logger.info(f"Adjusting max_tokens from {max_tokens} to {model_config['min_max_tokens']} (model minimum)")
                max_tokens = model_config["min_max_tokens"]
        else:
            if model_config["has_max_tokens_limit"]:
                max_tokens = model_config["default_max_tokens"]
                logger.info(f"Using default max_tokens {max_tokens} for model {model}")
            else:
                max_tokens = None
                logger.info(f"No max_tokens limit for model {model}, allowing unlimited")

        if max_tokens is not None:
            replicate_input["max_tokens"] = max_tokens

        # Standard sampling parameters pass through verbatim when present.
        if "temperature" in openai_request:
            replicate_input["temperature"] = openai_request["temperature"]

        if "top_p" in openai_request:
            replicate_input["top_p"] = openai_request["top_p"]

        if "frequency_penalty" in openai_request:
            replicate_input["frequency_penalty"] = openai_request["frequency_penalty"]

        if "presence_penalty" in openai_request:
            replicate_input["presence_penalty"] = openai_request["presence_penalty"]

        replicate_request = {
            "stream": openai_request.get("stream", False),
            "input": replicate_input
        }

        logger.info(f"Transformed request for model: {model}")
        logger.info(f"Message count: {len(messages)} (system: {1 if system_prompt else 0}, user/assistant: {len(user_messages)})")
        logger.info(f"Has images: {has_images}, Has files: {has_files}")
        if has_files:
            logger.info(f"Files: {[f['filename'] for f in all_files]}")
        logger.info(f"Parameters: max_tokens={max_tokens}, temperature={replicate_input.get('temperature', 'not set')}")

        return replicate_request, model

    except Exception as e:
        logger.error(f"Error transforming request: {str(e)}")
        raise HTTPException(status_code=400, detail=f"Request transformation error: {str(e)}")
|
|
def create_log_safe_data(data: Dict[str, Any]) -> Dict[str, Any]:
    """Return a deep copy of *data* with bulky fields redacted for logging."""
    # JSON round-trip yields an independent copy, so the caller's dict is
    # never mutated.
    sanitized = json.loads(json.dumps(data))

    if "input" in sanitized:
        payload = sanitized["input"]
        if "image" in payload:
            image = payload["image"]
            # Hosted URLs are short enough to log; inline data is summarized.
            if image.startswith("http"):
                payload["image"] = f"[IMAGE_URL: {image}]"
            else:
                payload["image"] = f"[IMAGE_DATA_{len(image)}]"
        if "prompt" in payload and len(payload["prompt"]) > 1000:
            payload["prompt"] = payload["prompt"][:1000] + "...[TRUNCATED]"

    return sanitized
|
|
async def create_replicate_prediction(session: aiohttp.ClientSession, model: str, data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Create a prediction on Replicate and return the parsed JSON response.

    Raises:
        HTTPException with the upstream status code on Replicate API errors,
        504 on timeout, and 500 on unexpected failures.
    """
    try:
        url = f"{REPLICATE_BASE_URL}/models/{model}/predictions"
        headers = {
            "Authorization": f"Bearer {REPLICATE_API_TOKEN}",
            "Content-Type": "application/json"
        }

        logger.info(f"Creating prediction for model: {model}")

        # Log a redacted copy so huge base64 payloads never hit the logs.
        log_data = create_log_safe_data(data)
        logger.info(f"Request data: {json.dumps(log_data, indent=2)}")

        async with session.post(url, headers=headers, json=data, timeout=30) as response:
            response_text = await response.text()
            logger.info(f"Replicate response status: {response.status}")

            if response.status != 201:
                logger.error(f"Replicate API error: {response.status} - {response_text}")
                raise HTTPException(
                    status_code=response.status,
                    detail=f"Replicate API error: {response_text}"
                )

            return json.loads(response_text)

    except HTTPException:
        # BUG FIX: previously the generic handler below swallowed the
        # upstream-status HTTPException raised above and re-raised it as 500,
        # hiding the real Replicate status code from clients.
        raise
    except asyncio.TimeoutError:
        logger.error("Timeout creating Replicate prediction")
        raise HTTPException(status_code=504, detail="Timeout creating prediction")
    except Exception as e:
        logger.error(f"Error creating prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Prediction creation error: {str(e)}")
|
|
class SSEParser:
    """Incremental line-by-line parser for Server-Sent Events streams."""

    def __init__(self):
        # Fields of the event currently being accumulated.
        self.event_type = None
        self.event_id = None
        self.data_buffer = []

    def parse_line(self, line: str):
        """Feed one decoded line; return a completed event dict or None."""
        if line == '':
            # A blank line terminates the current event, if one was started.
            if not (self.data_buffer or self.event_type):
                return None
            completed = {
                'event': self.event_type,
                'id': self.event_id,
                'data': '\n'.join(self.data_buffer),
            }
            self.event_type = None
            self.event_id = None
            self.data_buffer = []
            return completed

        if line.startswith('event: '):
            self.event_type = line[len('event: '):].strip()
        elif line.startswith('id: '):
            self.event_id = line[len('id: '):].strip()
        elif line.startswith('data: '):
            self.data_buffer.append(line[len('data: '):])
        # Lines beginning with ': ' are SSE comments; anything else is ignored.
        return None
|
|
def create_openai_chunk(content: str, model: str, prediction_id: str, finish_reason=None):
    """
    Build one OpenAI-style streaming chunk serialized as an SSE data line.

    When finish_reason is set the delta is left empty, matching OpenAI's
    final-chunk convention.
    """
    chunk = {
        "id": f"chatcmpl-{prediction_id}",
        "object": "chat.completion.chunk",
        # BUG FIX: "created" must be a Unix epoch timestamp; the event-loop
        # clock is monotonic with an arbitrary epoch and confuses clients.
        "created": int(time.time()),
        "model": model,
        "choices": [{
            "index": 0,
            "delta": {},
            "finish_reason": finish_reason
        }]
    }

    if content and not finish_reason:
        chunk["choices"][0]["delta"]["content"] = content

    return f"data: {json.dumps(chunk)}\n\n"
|
|
@app.get("/")
async def root():
    """Lightweight liveness/info endpoint describing proxy capabilities."""
    info = {
        "message": "Replicate API Proxy for LobeChat with Vision and File Support",
        "status": "running",
        "replicate_token_configured": bool(REPLICATE_API_TOKEN),
        "imgbb_token_configured": bool(IMGBB_API_KEY),
        "version": "1.3.0",
        "supported_models": list(MODEL_CONFIGS.keys()),
        "vision_support": True,
        "file_support": True,
        "external_image_support": True,
        "supported_text_files": list(SUPPORTED_TEXT_EXTENSIONS),
        "supported_image_files": list(SUPPORTED_IMAGE_EXTENSIONS),
        "claude4_vision_support": "Full support via imgbb image hosting",
    }
    return info
|
|
@app.get("/health")
async def health():
    """Detailed health check: token presence, model configs, file types."""
    return {
        "status": "healthy",
        "replicate_token": "configured" if REPLICATE_API_TOKEN else "missing",
        "imgbb_token": "configured" if IMGBB_API_KEY else "missing",
        # BUG FIX: report wall-clock Unix time; the event-loop clock is
        # monotonic with an arbitrary epoch and is meaningless to clients.
        "timestamp": time.time(),
        "model_configs": MODEL_CONFIGS,
        "supported_file_types": {
            "text": list(SUPPORTED_TEXT_EXTENSIONS),
            "image": list(SUPPORTED_IMAGE_EXTENSIONS)
        }
    }
|
|
@app.get("/v1/models")
async def list_models():
    """
    List available models (OpenAI-compatible).

    CONSISTENCY FIX: google/gemini-2.5-pro is present in MODEL_CONFIGS and
    the alias map in transform_openai_to_replicate but was missing here.
    """
    model_owners = [
        ("claude-4-sonnet", "anthropic"),
        ("claude-3.5-sonnet", "anthropic"),
        ("claude-3.5-haiku", "anthropic"),
        ("claude-3-sonnet", "anthropic"),
        ("claude-3-haiku", "anthropic"),
        ("gemini-2.5-pro", "google"),
    ]
    models = [
        {
            "id": model_id,
            "object": "model",
            "created": 1677610602,
            "owned_by": owner,
        }
        for model_id, owner in model_owners
    ]

    return {"object": "list", "data": models}
|
|
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """
    OpenAI-compatible chat completion endpoint.

    Converts the incoming request to Replicate's format, then either relays
    Replicate's SSE stream as OpenAI-style chunks (stream=True) or polls the
    prediction to completion and returns a single completion object.
    """
    if not REPLICATE_API_TOKEN:
        logger.error("REPLICATE_API_TOKEN not configured")
        raise HTTPException(status_code=500, detail="REPLICATE_API_TOKEN not configured")

    try:
        body = await request.json()
        logger.info(f"Received chat completion request")
        logger.info(f"Client parameters: max_tokens={body.get('max_tokens', 'not set')}, temperature={body.get('temperature', 'not set')}")
        logger.info(f"Message count: {len(body.get('messages', []))}")

        # Transform to Replicate format and resolve the concrete model slug.
        replicate_data, model = await transform_openai_to_replicate(body)

        if body.get("stream", False):

            async def generate_stream():
                # BUG FIX: pre-initialize so the TimeoutError handler below
                # cannot hit a NameError if prediction creation itself fails.
                prediction_id = None
                async with aiohttp.ClientSession() as session:
                    try:
                        prediction = await create_replicate_prediction(session, model, replicate_data)
                        prediction_id = prediction.get('id')
                        logger.info(f"Created prediction: {prediction_id}")

                        stream_url = prediction.get("urls", {}).get("stream")
                        if not stream_url:
                            error_response = {
                                "error": {
                                    "message": "Stream URL not available",
                                    "type": "stream_error"
                                }
                            }
                            yield f"data: {json.dumps(error_response)}\n\n"
                            return

                        logger.info(f"Starting stream from: {stream_url}")

                        headers = {
                            "Accept": "text/event-stream",
                            "Cache-Control": "no-store"
                        }

                        sse_parser = SSEParser()

                        async with session.get(stream_url, headers=headers, timeout=120) as response:
                            if response.status != 200:
                                error_text = await response.text()
                                logger.error(f"Stream error: {response.status} - {error_text}")
                                error_response = {
                                    "error": {
                                        "message": f"Stream error: {error_text}",
                                        "type": "stream_error"
                                    }
                                }
                                yield f"data: {json.dumps(error_response)}\n\n"
                                return

                            async for line in response.content:
                                line = line.decode('utf-8').rstrip('\r\n')

                                # NOTE(review): this filter drops ANY line
                                # containing "408" or "timeout", which can
                                # include legitimate model output; consider
                                # matching only actual keep-alive frames.
                                if '408' in line or 'timeout' in line.lower():
                                    logger.info(f"Ignoring timeout message: {line}")
                                    continue

                                event = sse_parser.parse_line(line)
                                if event:
                                    event_type = event.get('event')
                                    data = event.get('data', '')

                                    if event_type == 'output' and data.strip():
                                        yield create_openai_chunk(data, model, prediction_id)
                                    elif event_type == 'done':
                                        logger.info("Stream completed with done event")
                                        yield create_openai_chunk("", model, prediction_id, "stop")
                                        yield "data: [DONE]\n\n"
                                        return

                        # Upstream closed without a done event: finish anyway.
                        logger.info("Stream ended without done event, sending manual completion")
                        yield create_openai_chunk("", model, prediction_id, "stop")
                        yield "data: [DONE]\n\n"

                    except asyncio.TimeoutError:
                        logger.error("Stream timeout")
                        yield create_openai_chunk("", model, prediction_id or "unknown", "stop")
                        yield "data: [DONE]\n\n"
                    except Exception as e:
                        logger.error(f"Stream generation error: {e}")
                        logger.error(f"Traceback: {traceback.format_exc()}")
                        error_response = {
                            "error": {
                                "message": str(e),
                                "type": "stream_error"
                            }
                        }
                        yield f"data: {json.dumps(error_response)}\n\n"

            return StreamingResponse(
                generate_stream(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "Access-Control-Allow-Origin": "*",
                    "X-Accel-Buffering": "no",
                }
            )

        else:
            # Non-streaming: create the prediction and poll until it settles.
            async with aiohttp.ClientSession() as session:
                prediction = await create_replicate_prediction(session, model, replicate_data)
                prediction_id = prediction.get('id')
                logger.info(f"Created prediction: {prediction_id}")

                prediction_url = f"{REPLICATE_BASE_URL}/predictions/{prediction_id}"
                headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}"}

                max_attempts = 60  # 1-second polls => roughly 60s budget
                attempt = 0

                while attempt < max_attempts:
                    async with session.get(prediction_url, headers=headers) as response:
                        result = await response.json()
                        status = result.get("status")

                        logger.info(f"Prediction {prediction_id} status: {status}")

                        if status == "succeeded":
                            output = result.get("output", [])
                            content = "".join(output) if isinstance(output, list) else str(output)

                            openai_response = {
                                "id": f"chatcmpl-{prediction_id}",
                                "object": "chat.completion",
                                # BUG FIX: "created" must be Unix epoch time,
                                # not the monotonic event-loop clock.
                                "created": int(time.time()),
                                "model": model,
                                "choices": [{
                                    "index": 0,
                                    "message": {
                                        "role": "assistant",
                                        "content": content
                                    },
                                    "finish_reason": "stop"
                                }],
                                # Replicate does not report token usage; the
                                # word count is a rough stand-in.
                                "usage": {
                                    "prompt_tokens": 0,
                                    "completion_tokens": len(content.split()),
                                    "total_tokens": len(content.split())
                                }
                            }
                            return openai_response

                        elif status == "failed":
                            error_msg = result.get('error', 'Unknown error')
                            logger.error(f"Prediction failed: {error_msg}")
                            raise HTTPException(status_code=500, detail=f"Prediction failed: {error_msg}")

                        elif status in ["canceled", "cancelled"]:
                            raise HTTPException(status_code=500, detail="Prediction was canceled")

                    await asyncio.sleep(1)
                    attempt += 1

                raise HTTPException(status_code=504, detail="Prediction timeout")

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error processing request: {str(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
|
if __name__ == "__main__":
    # Port 7860 is the Hugging Face Spaces default; override with PORT.
    server_port = int(os.environ.get("PORT", 7860))
    logger.info(f"Starting server on port {server_port}")
    uvicorn.run(app, host="0.0.0.0", port=server_port, log_level="info")