| """Model wrapper for LiteLLM""" |
|
|
| import os |
| import json |
| from typing import List, Dict, Any, Optional |
|
|
| try: |
| import litellm |
| except ImportError: |
| print("⚠️ litellm not installed. Install with: pip install litellm") |
| litellm = None |
|
|
|
|
class LiteLLMModel:
    """Wrapper around ``litellm.completion`` for chat-style models.

    Models whose id contains "groq" (case-insensitive) get their API key
    read from the ``GROQ_API_KEY`` environment variable and passed
    explicitly; all other models rely on litellm's own credential lookup.
    """

    def __init__(self, model_id: str):
        """Store the model id and fail fast if a Groq model lacks its key.

        Args:
            model_id: litellm model identifier (e.g. "groq/<model-name>").

        Raises:
            RuntimeError: if ``model_id`` refers to Groq and
                ``GROQ_API_KEY`` is not set.
        """
        self.model_id = model_id

        if "groq" in model_id.lower() and not os.getenv("GROQ_API_KEY"):
            print("⚠️ GROQ_API_KEY not set in environment")
            raise RuntimeError("GROQ_API_KEY not set. Please add it to your Space secrets.")

    def generate(self, messages: List[Dict], tools: Optional[List] = None) -> Dict:
        """Run one chat completion and normalise the reply into a plain dict.

        Args:
            messages: chat messages, passed through to litellm unchanged.
            tools: optional tool objects; each must expose ``name``,
                ``description`` and ``parameters`` attributes.

        Returns:
            ``{"content": str}``. When the model requested tool calls, the
            dict also has ``"tool_calls"``: a list of
            ``{"id": str, "name": str, "arguments": dict}``.
            This method never raises. If litellm is missing, the Groq key
            is absent, or the API call fails, it returns a placeholder
            ``{"content": "Unknown..."}`` dict instead.
        """
        if litellm is None:
            return {"content": "Unknown - litellm not installed"}

        try:
            formatted_tools = None
            if tools:
                # litellm expects the OpenAI-style "function" tool schema.
                formatted_tools = [
                    {
                        "type": "function",
                        "function": {
                            "name": tool.name,
                            "description": tool.description,
                            "parameters": tool.parameters,
                        },
                    }
                    for tool in tools
                ]

            completion_kwargs = {
                "model": self.model_id,
                "messages": messages,
                "tools": formatted_tools,
                "temperature": 0.1,
            }

            if "groq" in self.model_id.lower():
                api_key = os.getenv("GROQ_API_KEY")
                if not api_key:
                    # Re-checked per call because the env may have changed
                    # since __init__; caught below -> "Unknown" placeholder.
                    raise RuntimeError("GROQ_API_KEY not set in environment")

                print(f"DEBUG: Using Groq model: {self.model_id}")
                completion_kwargs["api_key"] = api_key

            response = litellm.completion(**completion_kwargs)

            message = response.choices[0].message
            result = {"content": message.content or ""}

            tool_calls = getattr(message, "tool_calls", None)
            if tool_calls:
                parsed_calls = []
                for tc in tool_calls:
                    name = tc.function.name
                    args = tc.function.arguments
                    if isinstance(args, str):
                        try:
                            args = json.loads(args)
                        except json.JSONDecodeError:
                            # Malformed JSON from the model: call with no arguments.
                            args = {}
                    if args is None:
                        # Missing arguments, or the literal JSON "null".
                        args = {}

                    parsed_calls.append({
                        # Some providers send id=None; synthesise a stable id.
                        "id": getattr(tc, "id", None) or f"call_{name}",
                        "name": name,
                        "arguments": args,
                    })
                result["tool_calls"] = parsed_calls

            return result

        except Exception as e:
            # Deliberate best-effort boundary: callers expect a dict, never an exception.
            print(f"Model error: {e}")
            return {"content": "Unknown"}
|
|
|
|
def get_model(model_type: str, model_id: str):
    """Build a model wrapper from its type name.

    Args:
        model_type: name of the wrapper class; only "LiteLLMModel" is known.
        model_id: identifier forwarded to the wrapper's constructor.

    Returns:
        A ``LiteLLMModel`` bound to ``model_id``.

    Raises:
        ValueError: if ``model_type`` is not a recognised wrapper name.
    """
    if model_type != "LiteLLMModel":
        raise ValueError(f"Unknown model type: {model_type}")
    return LiteLLMModel(model_id)