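"""FastAPI service exposing per-session MCP clients backed by Gemini's
OpenAI-compatible endpoint: POST /init creates a session, POST /chat runs an
(optionally tool-assisted) query against it, and POST /logout tears it down."""
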
import json
import traceback
import uuid
from contextlib import AsyncExitStack
from typing import Dict, List, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from openai import APIConnectionError, OpenAI
from pydantic import BaseModel

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

# Live sessions, keyed by the session_id handed out by /init.
sessions: Dict[str, "MCPClient"] = {}
# API keys that already own a session; /init rejects duplicates.
unique_apikeys: List[str] = []


class MCPClient:
    def __init__(self):
        self.session: Optional[ClientSession] = None
        self.exit_stack = AsyncExitStack()
        self.current_model: Optional[str] = None
        # Full conversation history for this session, in OpenAI message format.
        self.messages: list = []
        self.openai: Optional[OpenAI] = None
        self.api_key: Optional[str] = None
        # When False, /chat skips tool definitions and does a plain completion.
        self.tool_use: bool = True
        self.models: Optional[List[str]] = None
        self.tools: list = []

    async def connect(self, api_key: str):
        try:
            # Gemini's OpenAI-compatible endpoint; api_key is a Google AI Studio key.
            self.openai = OpenAI(
                base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
                api_key=api_key,
            )
            self.api_key = api_key
        except APIConnectionError:
            traceback.print_exc()
            return False
        except Exception:
            traceback.print_exc()
            return False

        # Launch the MCP server as a subprocess and connect to it over stdio.
        server_params = StdioServerParameters(command="uv", args=["--directory", "/app", "run", "server.py"])
        stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
        self.stdio, self.write = stdio_transport
        self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))

        await self.session.initialize()

        # Expose the server's tools to the model in OpenAI function-calling format.
        response = await self.session.list_tools()
        self.tools = [{
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.inputSchema,
            },
        } for tool in response.tools]
        return True

    def populate_model(self):
        """Fetch and cache the model ids available to this API key."""
        self.models = sorted(m.id for m in self.openai.models.list().data)

    async def process_query(self, query: str) -> str:
        """Process a query using Gemini and the available MCP tools."""
        self.messages.append({"role": "user", "content": query})

        if self.tool_use:
            response = self.openai.chat.completions.create(
                model=self.current_model,
                messages=self.messages,
                tools=self.tools,
                temperature=0,
            )
        else:
            response = self.openai.chat.completions.create(
                model=self.current_model,
                messages=self.messages,
                temperature=0.65,
            )

        final_text = []

        for choice in response.choices:
            content = choice.message.content
            tool_calls = choice.message.tool_calls
            if content:
                # Plain text answer: record it in the history and stop.
                final_text.append(content)
                self.messages.append({"role": "assistant", "content": content})
                break
            if tool_calls:
                # The SDK re-serializes the tool_call objects when the history is resent.
                self.messages.append({"role": "assistant", "tool_calls": tool_calls})

                for tool in tool_calls:
                    tool_name = tool.function.name
                    tool_args = tool.function.arguments

                    result = await self.session.call_tool(tool_name, json.loads(tool_args))
                    print(f"[Calling tool {tool_name} with args {tool_args}]")

                    self.messages.append({
                        "role": "tool",
                        "tool_call_id": tool.id,
                        "content": str(result),
                    })

                # Feed the tool results back to the model for the final answer.
                response2 = self.openai.chat.completions.create(
                    model=self.current_model,
                    messages=self.messages,
                    temperature=0.7,
                )

                final_text.append(response2.choices[0].message.content)
                self.messages.append({
                    "role": "assistant",
                    "content": response2.choices[0].message.content,
                })
        return "\n".join(final_text)

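# Exercising the client outside FastAPI (a sketch, not part of the service;
# assumes a valid Gemini key in the GEMINI_API_KEY env var and the MCP server
# at /app/server.py):
#
#   import asyncio, os
#
#   async def demo():
#       client = MCPClient()
#       await client.connect(os.environ["GEMINI_API_KEY"])
#       client.populate_model()
#       client.current_model = client.models[0]
#       print(await client.process_query("hello"))
#
#   asyncio.run(demo())
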
app = FastAPI()
app.add_middleware(CORSMiddleware, allow_credentials=True, allow_headers=["*"], allow_methods=["*"], allow_origins=["*"])


class InitRequest(BaseModel):
    api_key: str


class InitResponse(BaseModel):
    success: bool
    session_id: str
    models: Optional[list] = None
    error: Optional[str] = None


class LogoutRequest(BaseModel):
    session_id: str


def get_mcp_client(session_id: str) -> MCPClient:
    """Get the MCPClient for a given session_id, or raise 404."""
    client = sessions.get(session_id)
    if client is None:
        raise HTTPException(status_code=404, detail="Invalid session_id. Please re-initialize.")
    return client


@app.get("/")
def root():
    return FileResponse("index.html")


@app.post("/init", response_model=InitResponse)
async def init_server(req: InitRequest):
    """
    Initializes a new MCP client session. Returns a session_id.
    """
    api_key = req.api_key
    session_id = str(uuid.uuid4())
    mcp = MCPClient()

    try:
        # Reject duplicate keys up front, before spawning an MCP subprocess
        # or registering anything that would need rolling back.
        if api_key in unique_apikeys:
            raise Exception("A session with this API key already exists; the session ID will not be re-issued.")
        ok = await mcp.connect(api_key)
        if ok is False:
            raise RuntimeError("Failed to connect to MCP or OpenAI with the given API key.")
        mcp.populate_model()

        sessions[session_id] = mcp
        unique_apikeys.append(api_key)
        return InitResponse(
            session_id=session_id,
            models=mcp.models,
            error=None,
            success=True,
        )
    except Exception as e:
        traceback.print_exc()
        return InitResponse(
            session_id="",
            models=None,
            error=str(e),
            success=False,
        )


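# Example /init call (assuming the service listens on localhost:8000):
#
#   curl -X POST http://localhost:8000/init \
#        -H 'Content-Type: application/json' \
#        -d '{"api_key": "<GEMINI_API_KEY>"}'

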
class ChatRequest(BaseModel):
    session_id: str
    query: str
    tool_use: Optional[bool] = True
    model: Optional[str] = "models/gemini-2.0-flash"


class ChatResponse(BaseModel):
    output: str
    error: Optional[str] = None


@app.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
    """
    Handles chat requests for a given session.
    """
    try:
        mcp = get_mcp_client(req.session_id)
        mcp.tool_use = req.tool_use
        if req.model in mcp.models:
            mcp.current_model = req.model
        else:
            raise ValueError(f"Unrecognized model {req.model!r}; available models: {mcp.models}")
        result = await mcp.process_query(req.query)
        return ChatResponse(output=result)
    except HTTPException:
        # Let the 404 from get_mcp_client propagate instead of masking it.
        raise
    except Exception as e:
        traceback.print_exc()
        return ChatResponse(output="", error=str(e))

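# Example /chat call (assuming the service listens on localhost:8000, with a
# session_id returned by /init):
#
#   curl -X POST http://localhost:8000/chat \
#        -H 'Content-Type: application/json' \
#        -d '{"session_id": "<SESSION_ID>", "query": "hello", "tool_use": true}'
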
@app.post("/logout")
async def logout(logout_req: LogoutRequest):
    """Clean up session resources."""
    mcp = sessions.pop(logout_req.session_id, None)
    if mcp is None:
        return {"success": True}
    # Free the API key so the same key can /init again later.
    if mcp.api_key in unique_apikeys:
        unique_apikeys.remove(mcp.api_key)
    try:
        await mcp.exit_stack.aclose()
    except RuntimeError:
        # stdio transports can raise on teardown from a different task; ignore.
        pass
    return {"success": True}
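

# Minimal local entrypoint (a sketch; assumes uvicorn is installed, and the
# host/port are arbitrary choices, not requirements of the service):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)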