| import os |
| import gradio as gr |
| import requests |
| import pandas as pd |
| from typing import Dict, List, Any, Optional, TypedDict, Annotated |
| import re |
| import numpy as np |
| from datetime import datetime |
|
|
| |
| from langchain_anthropic import ChatAnthropic |
| from langchain_core.messages import HumanMessage, SystemMessage, BaseMessage, AIMessage |
| from langchain_core.tools import tool |
| from serpapi import GoogleSearch |
| from langgraph.graph import StateGraph, END |
| from langgraph.prebuilt import ToolNode |
| from langgraph.graph.message import add_messages |
| import numexpr |
| from dotenv import load_dotenv |
|
|
| |
# Load environment variables (ANTHROPIC_API_KEY, TAVILY_API_KEY, SERPAPI_KEY, ...)
# from a local .env file, if present.
load_dotenv()


# Default scoring endpoint for the HF Agents-course Unit 4 evaluation.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
|
| |
class AgentState(TypedDict):
    """LangGraph state: the running conversation as a list of messages."""
    # add_messages is a reducer: new messages are appended/merged into the
    # existing list instead of replacing it on each graph step.
    messages: Annotated[List[BaseMessage], add_messages]
|
|
| |
@tool
def web_search(query: str, max_results: int = 8) -> str:
    """
    Search the web for *query*, trying providers in order of quality:

    1. Tavily (if TAVILY_API_KEY is set) - can return a synthesized answer.
    2. DuckDuckGo instant-answer API - no API key required.
    3. SerpAPI Google search (if SERPAPI_KEY is set).

    Args:
        query: Search query (lists / non-strings are coerced to a string).
        max_results: Maximum number of results to format per provider.

    Returns:
        Formatted result text, or an explanatory message when no provider
        is usable.
    """
    try:
        # Tool args occasionally arrive as a list from the LLM; normalize.
        if isinstance(query, list):
            query = " ".join(str(item) for item in query)
        elif not isinstance(query, str):
            query = str(query)

        # --- Tier 1: Tavily (requires API key) ------------------------------
        tavily_api_key = os.getenv("TAVILY_API_KEY")
        if tavily_api_key:
            try:
                tavily_url = "https://api.tavily.com/search"
                tavily_headers = {
                    "Content-Type": "application/json"
                }
                tavily_data = {
                    "api_key": tavily_api_key,
                    "query": query,
                    "search_depth": "advanced",
                    "include_answer": True,
                    "include_raw_content": False,
                    "max_results": max_results
                }

                response = requests.post(tavily_url, json=tavily_data, headers=tavily_headers, timeout=10)
                if response.status_code == 200:
                    results = response.json()
                    formatted_results = []

                    # Tavily can synthesize a direct answer; surface it first.
                    if results.get("answer"):
                        formatted_results.append(f"DIRECT ANSWER: {results['answer']}")

                    if results.get("results"):
                        for i, result in enumerate(results["results"][:max_results], 1):
                            title = result.get("title", "")
                            content = result.get("content", "")
                            url = result.get("url", "")
                            formatted_results.append(f"{i}. {title}\n {content}\n Source: {url}")

                    if formatted_results:
                        return "\n\n".join(formatted_results)

            except Exception as tavily_error:
                # Any Tavily failure falls through to the next provider.
                print(f"Tavily search error: {tavily_error}")

        # --- Tier 2: DuckDuckGo instant answers (no key needed) -------------
        try:
            from urllib.parse import quote

            ddg_success = False
            formatted_results = []

            # Two attempts: the instant-answer endpoint is occasionally flaky.
            for attempt in range(2):
                try:
                    ddg_url = f"https://api.duckduckgo.com/?q={quote(query)}&format=json&no_html=1"
                    response = requests.get(ddg_url, timeout=5)

                    if response.status_code == 200:
                        ddg_data = response.json()

                        if ddg_data.get("Answer"):
                            formatted_results.append(f"DIRECT ANSWER: {ddg_data['Answer']}")
                            ddg_success = True

                        if ddg_data.get("Abstract"):
                            formatted_results.append(f"SUMMARY: {ddg_data['Abstract']}")
                            ddg_success = True

                        if ddg_data.get("Definition"):
                            formatted_results.append(f"DEFINITION: {ddg_data['Definition']}")
                            ddg_success = True

                    if ddg_success:
                        break
                # Was a bare "except:"; narrowed so Ctrl-C / SystemExit are
                # no longer swallowed by the retry loop.
                except Exception:
                    if attempt == 0:
                        print("DuckDuckGo attempt 1 failed, retrying...")
                    continue

            # No instant answer: emit guidance text so the agent can adapt.
            if not ddg_success:
                print("DuckDuckGo unavailable, checking alternatives...")

                if "wikipedia" in query.lower() or "featured article" in query.lower():
                    formatted_results.append(f"Search query: {query}")
                    formatted_results.append("Note: For Wikipedia Featured Articles, check Wikipedia's FA archives")
                    formatted_results.append("Tip: Featured Articles are promoted monthly and listed in Wikipedia's FA log")
                else:
                    query_lower = query.lower() if isinstance(query, str) else str(query).lower()
                    if "who is" in query_lower or "who was" in query_lower:
                        formatted_results.append(f"Search query: {query}")
                        formatted_results.append("Note: Live web search unavailable. Please verify information.")
                    elif any(word in query_lower for word in ["when", "what year", "what date"]):
                        formatted_results.append(f"Search query: {query}")
                        formatted_results.append("Note: For current dates and recent events, web search is limited.")
                    else:
                        formatted_results.append(f"Search query: {query}")
                        formatted_results.append("Note: Web search temporarily unavailable.")

            if formatted_results:
                return "\n\n".join(formatted_results)

        except Exception as ddg_error:
            print(f"DuckDuckGo search error: {ddg_error}")

        # --- Tier 3: SerpAPI Google search (requires API key) ---------------
        api_key = os.getenv("SERPAPI_KEY")
        if api_key:
            params = {
                "q": query,
                "api_key": api_key,
                "num": max_results,
                "engine": "google",
                "hl": "en",
                "gl": "us"
            }

            search = GoogleSearch(params)
            results = search.get_dict()

            formatted_results = []

            # Google's "answer box" is the most direct response when present.
            if "answer_box" in results:
                ab = results["answer_box"]
                if "answer" in ab:
                    formatted_results.append(f"DIRECT ANSWER: {ab['answer']}")
                elif "snippet" in ab:
                    formatted_results.append(f"ANSWER BOX: {ab['snippet']}")

            if "organic_results" in results:
                for i, result in enumerate(results["organic_results"][:max_results], 1):
                    title = result.get("title", "")
                    snippet = result.get("snippet", "")
                    formatted_results.append(f"{i}. {title}\n {snippet}")

            return "\n\n".join(formatted_results) if formatted_results else "No results found"

        return "No search service available. Please set SERPAPI_KEY or check internet connection."

    except Exception as e:
        return f"Search error: {str(e)}"
|
|
@tool
def calculator(expression: str) -> str:
    """
    Enhanced calculator with unit conversion and advanced functions.
    Supports: arithmetic, percentages, trigonometry, logarithms, unit conversion.
    Examples: "15% of 200", "sqrt(16)", "convert 5 km to miles"

    Args:
        expression: The expression to evaluate (coerced to str if needed).

    Returns:
        The result formatted as a plain string, or an error message.
    """
    try:
        # Tool args occasionally arrive as a list from the LLM; normalize.
        if isinstance(expression, list):
            expression = " ".join(str(item) for item in expression)
        elif not isinstance(expression, str):
            expression = str(expression)

        expression = expression.strip().lower()

        # Percentage phrasing: "<p>% of <v>".
        if "% of" in expression:
            parts = expression.split("% of")
            if len(parts) == 2:
                percent = float(parts[0].strip())
                value = float(parts[1].strip())
                result = (percent / 100) * value
                return str(result)

        # Simple pairwise unit conversions; callables handle the affine
        # temperature conversions.
        if "convert" in expression or " to " in expression:
            conversions = {
                "km to miles": 0.621371,
                "miles to km": 1.60934,
                "kg to lbs": 2.20462,
                "lbs to kg": 0.453592,
                "celsius to fahrenheit": lambda c: (c * 9/5) + 32,
                "fahrenheit to celsius": lambda f: (f - 32) * 5/9,
                "meters to feet": 3.28084,
                "feet to meters": 0.3048,
                "liters to gallons": 0.264172,
                "gallons to liters": 3.78541
            }

            for conv, factor in conversions.items():
                if conv in expression:
                    numbers = re.findall(r'[\d.]+', expression)
                    if numbers:
                        value = float(numbers[0])
                        if callable(factor):
                            result = factor(value)
                        else:
                            result = value * factor
                        return f"{result:.4f}".rstrip('0').rstrip('.')

        # Substitute mathematical constants ONLY as standalone tokens.
        # The previous blanket replace of "e" corrupted scientific notation
        # ("2e3") and any word containing "e"; \b boundaries avoid that.
        expression = re.sub(r'\bpi\b', '3.14159265359', expression)
        expression = re.sub(r'\be\b', '2.71828182846', expression)

        # Strip stray words (units, filler) but KEEP the function names that
        # numexpr evaluates natively.  The previous code stripped ALL letters,
        # which silently turned "sqrt(16)" into "(16)".
        _known_funcs = {'sqrt', 'sin', 'cos', 'tan', 'arcsin', 'arccos',
                        'arctan', 'log', 'log10', 'exp', 'abs'}
        expression = re.sub(
            r'[a-zA-Z]+',
            lambda m: m.group(0) if m.group(0) in _known_funcs else '',
            expression,
        )

        # numexpr returns a 0-d ndarray; .item() converts it to a native
        # int/float so the formatting branches below actually fire.
        result = numexpr.evaluate(expression).item()

        # Format: integers plain, tiny values as 0, huge values in
        # scientific notation, everything else trimmed to 6 decimals.
        if isinstance(result, (int, np.integer)):
            return str(int(result))
        elif isinstance(result, (float, np.floating)):
            if abs(result) < 1e-10:
                return "0"
            elif abs(result) > 1e10:
                return f"{result:.2e}"
            else:
                formatted = f"{result:.6f}".rstrip('0').rstrip('.')
                if float(formatted).is_integer():
                    return str(int(float(formatted)))
                return formatted
        else:
            return str(result)

    except Exception as e:
        # Last resort: restricted eval with only the math module in scope.
        # NOTE: eval of model-supplied text is not a security boundary.
        try:
            import math
            result = eval(expression, {"__builtins__": {}, "math": math})
            if isinstance(result, float) and result.is_integer():
                return str(int(result))
            return str(result)
        except Exception:
            return f"Calculation error: {str(e)}"
|
|
@tool
def python_executor(code: str) -> str:
    """
    Enhanced Python executor with data analysis capabilities.
    Pre-loaded names: pandas/pd, numpy/np, math, datetime, json, re,
    statistics, itertools, collections (plus Counter/defaultdict).
    Always print the final result you want to return.

    Args:
        code: Python source to execute (lists are joined with newlines).

    Returns:
        Captured stdout, the value of a `result`/`answer`/`output` variable,
        or an error message with traceback.
    """
    try:
        # Tool args occasionally arrive as a list of lines; normalize.
        if isinstance(code, list):
            code = "\n".join(str(item) for item in code)
        elif not isinstance(code, str):
            code = str(code)

        # Curated execution namespace.  NOTE: exec of LLM-generated code is
        # convenience sandboxing only, NOT a security boundary.
        safe_globals = {
            '__builtins__': {
                'print': print,
                'len': len,
                'range': range,
                'sum': sum,
                'min': min,
                'max': max,
                'abs': abs,
                'round': round,
                'sorted': sorted,
                'reversed': reversed,
                'enumerate': enumerate,
                'zip': zip,
                'map': map,
                'filter': filter,
                'str': str,
                'int': int,
                'float': float,
                'list': list,
                'dict': dict,
                'set': set,
                'tuple': tuple,
                'bool': bool,
                'all': all,
                'any': any,
                'isinstance': isinstance,
                'type': type,
                # Required for any `import` statement in the executed code
                # (including the datetime import prepended below).  Without
                # it every import raised "ImportError: __import__ not found",
                # making most snippets fail outright.
                '__import__': __import__,
            },
            'math': __import__('math'),
            'datetime': __import__('datetime'),
            'json': __import__('json'),
            're': __import__('re'),
            'numpy': __import__('numpy'),
            'np': __import__('numpy'),
            'pandas': __import__('pandas'),
            'pd': __import__('pandas'),
            'statistics': __import__('statistics'),
            'itertools': __import__('itertools'),
            'collections': __import__('collections'),
            'Counter': __import__('collections').Counter,
            'defaultdict': __import__('collections').defaultdict,
        }

        from io import StringIO
        import sys

        # Capture stdout so `print` output becomes the tool's return value.
        old_stdout = sys.stdout
        sys.stdout = output_buffer = StringIO()

        try:
            # Convenience: expose datetime/date/timedelta classes unless the
            # snippet imports them itself.
            enhanced_code = code
            if "from datetime" not in code and "import datetime" not in code:
                enhanced_code = "from datetime import datetime, date, timedelta\n" + enhanced_code

            exec(enhanced_code, safe_globals)
            output = output_buffer.getvalue().strip()

            # Nothing printed: fall back to conventional result variables.
            if not output:
                for var in ['result', 'answer', 'output']:
                    if var in safe_globals:
                        output = str(safe_globals[var])
                        break

            return output if output else "No output (add print statement)"
        finally:
            # Always restore stdout, even when exec raises.
            sys.stdout = old_stdout

    except Exception as e:
        import traceback
        return f"Error: {str(e)}\nTraceback: {traceback.format_exc()}"
|
|
@tool
def extract_image_from_question(question: str) -> str:
    """
    Extract and analyze images mentioned in questions.
    For GAIA benchmark, images are typically base64 encoded or referenced.
    """
    try:
        # Coerce non-string tool input (lists, numbers) to a plain string.
        if isinstance(question, list):
            question = " ".join(str(part) for part in question)
        elif not isinstance(question, str):
            question = str(question)

        # Inline base64 image payloads use a "data:image" URI scheme.
        if "data:image" in question:
            return "Image data detected in question"

        lowered = question.lower()

        # First matching image-file extension, if any.
        matched_ext = next(
            (ext for ext in ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.svg')
             if ext in lowered),
            None,
        )
        if matched_ext is not None:
            return f"Image file reference detected: {matched_ext}"

        # Generic image-related vocabulary.
        if any(word in lowered for word in
               ('image', 'picture', 'photo', 'diagram', 'figure', 'screenshot')):
            return "Image-related content mentioned in question"

        return "No image content detected"
    except Exception as e:
        return f"Error analyzing for images: {str(e)}"
|
|
@tool
def analyze_attachments(question: str) -> str:
    """
    Analyze questions for references to attachments (files, videos, audio).
    For GAIA questions that reference external content.
    """
    # Coerce non-string tool input (lists, numbers) to a plain string.
    if isinstance(question, list):
        question = " ".join(str(part) for part in question)
    elif not isinstance(question, str):
        question = str(question)

    found = []

    # YouTube links: long form (watch?v=) and short form (youtu.be).
    for yt_pattern in (r'youtube\.com/watch\?v=([a-zA-Z0-9_-]+)',
                       r'youtu\.be/([a-zA-Z0-9_-]+)'):
        hits = re.findall(yt_pattern, question)
        if hits:
            found.append(f"YouTube video: {hits[0]}")

    # Direct links to downloadable documents.
    url_pattern = r'https?://[^\s<>"{}|\\^`\[\]]+\.(?:xlsx|xls|csv|pdf|txt)'
    url_matches = re.findall(url_pattern, question, re.IGNORECASE)
    for url in url_matches:
        if '.xlsx' in url or '.xls' in url:
            found.append(f"Excel file URL: {url}")
        elif '.csv' in url:
            found.append(f"CSV file URL: {url}")
        elif '.pdf' in url:
            found.append(f"PDF file URL: {url}")
        elif '.txt' in url:
            found.append(f"Text file URL: {url}")

    # Loose textual references to files ("attached X file", bare filenames),
    # skipping anything already captured as part of a URL above.
    for file_pattern in (r'attached (\w+) file',
                         r'the (\w+) file',
                         r'(\w+\.\w{2,4})'):
        for hit in re.findall(file_pattern, question, re.IGNORECASE):
            if not any(hit in url for url in url_matches):
                found.append(f"File reference: {hit}")

    if found:
        return "Attachments found: " + ", ".join(found)
    return "No attachments detected"
|
|
@tool
def analyze_reversed_text(text: str) -> str:
    """
    Analyze text that might be written backwards or contains puzzles.
    Useful for GAIA questions with reversed text.
    """
    try:
        # Coerce non-string tool input (lists, numbers) to a plain string.
        if isinstance(text, list):
            text = " ".join(str(part) for part in text)
        elif not isinstance(text, str):
            text = str(text)

        flipped = text[::-1]

        # "answer"/"question" spelled backwards strongly suggests the whole
        # input is reversed.
        lowered = text.lower()
        if "rewsna" in lowered or "noitseuq" in lowered:
            return f"Text appears to be reversed. Original: {flipped}"

        # Otherwise report both whole-string and per-word reversals.
        word_flips = " ".join(word[::-1] for word in text.split())
        return f"Normal text: {text}\nReversed text: {flipped}\nReversed words: {word_flips}"
    except Exception as e:
        return f"Error analyzing text: {str(e)}"
|
|
@tool
def analyze_code_in_question(question: str) -> str:
    """
    Detect and extract Python code from questions.
    Looks for code blocks, inline code, and code-related phrases.
    """
    try:
        # Coerce non-string tool input (lists, numbers) to a plain string.
        if isinstance(question, list):
            question = " ".join(str(item) for item in question)
        elif not isinstance(question, str):
            question = str(question)

        extracted_code = []

        # 1) Fenced blocks explicitly tagged as Python.
        code_block_pattern = r'```python\s*(.*?)\s*```'
        code_blocks = re.findall(code_block_pattern, question, re.DOTALL | re.IGNORECASE)
        if code_blocks:
            for i, code in enumerate(code_blocks, 1):
                extracted_code.append(f"Code Block {i}:\n{code.strip()}")

        # 2) Untagged fenced blocks that look like Python.
        #    NOTE(review): this pattern also re-matches the ```python blocks
        #    handled above, so tagged blocks may be reported twice -- confirm
        #    that duplication is acceptable.
        generic_code_pattern = r'```\s*(.*?)\s*```'
        generic_blocks = re.findall(generic_code_pattern, question, re.DOTALL)
        if generic_blocks:
            for i, code in enumerate(generic_blocks, 1):
                # Heuristic: only keep blocks containing Python keywords.
                if any(keyword in code for keyword in ['def ', 'import ', 'class ', 'if ', 'for ', 'while ', 'print(', 'return ']):
                    extracted_code.append(f"Generic Code Block {i}:\n{code.strip()}")

        # 3) Inline `code` spans that contain code-like characters/keywords.
        inline_code_pattern = r'`([^`]+)`'
        inline_codes = re.findall(inline_code_pattern, question)
        if inline_codes:
            python_inline = []
            for code in inline_codes:
                if any(char in code for char in ['(', ')', '=', '[', ']', '{', '}', 'def', 'import', 'print']):
                    python_inline.append(code)
            if python_inline:
                extracted_code.append("Inline Code:\n" + "\n".join(f"- {code}" for code in python_inline))

        # 4) Phrases implying code is referenced even when none is embedded.
        code_phrases = [
            r'attached python code',
            r'the following code',
            r'this code',
            r'given code',
            r'code snippet',
            r'python script',
            r'the script',
            r'function below',
            r'class below',
            r'program below'
        ]

        code_indicators = []
        for phrase in code_phrases:
            if re.search(phrase, question, re.IGNORECASE):
                # NOTE(review): replace(r'\\', '') removes literal backslashes,
                # but none of the phrases contain any -- effectively a no-op.
                code_indicators.append(phrase.replace(r'\\', ''))

        # 5) Loose, unfenced Python syntax (defs, imports, control flow).
        python_patterns = [
            r'def\s+\w+\s*\([^)]*\)\s*:',
            r'class\s+\w+\s*(?:\([^)]*\))?\s*:',
            r'import\s+\w+',
            r'from\s+\w+\s+import',
            r'if\s+.*:\s*\n',
            r'for\s+\w+\s+in\s+',
            r'while\s+.*:\s*\n',
        ]

        loose_code = []
        for pattern in python_patterns:
            matches = re.findall(pattern, question, re.MULTILINE)
            if matches:
                loose_code.extend(matches)

        if loose_code:
            # Cap at 5 matches to keep the report short.
            extracted_code.append("Detected Python patterns:\n" + "\n".join(f"- {code.strip()}" for code in loose_code[:5]))

        # Assemble the final report from whatever was found.
        response_parts = []

        if extracted_code:
            response_parts.append("Found Python code in question:")
            response_parts.extend(extracted_code)

        if code_indicators:
            response_parts.append(f"\nCode-related phrases detected: {', '.join(code_indicators)}")

        if not extracted_code and not code_indicators:
            return "No Python code detected in the question"

        return "\n\n".join(response_parts)

    except Exception as e:
        return f"Error analyzing code in question: {str(e)}"
|
|
@tool
def get_youtube_transcript(url: str) -> str:
    """
    Extract transcript/subtitles from YouTube videos.
    Useful for questions asking about video content.

    Fallback chain: youtube-transcript-api (English, then manual, then
    auto-generated) -> pytube metadata -> "Unable to determine".
    """
    try:
        # Coerce non-string tool input (lists, numbers) to a plain string.
        if isinstance(url, list):
            url = " ".join(str(item) for item in url)
        elif not isinstance(url, str):
            url = str(url)

        # YouTube video IDs are exactly 11 chars; accept watch?v= and
        # youtu.be/ short-link forms.
        import re
        video_id_match = re.search(r'(?:v=|/)([0-9A-Za-z_-]{11}).*', url)
        if not video_id_match:
            return "Error: Invalid YouTube URL"

        video_id = video_id_match.group(1)

        # Primary path: youtube-transcript-api.
        try:
            from youtube_transcript_api import YouTubeTranscriptApi
            import time

            # Brief delay to reduce the chance of HTTP 429 rate limiting.
            time.sleep(1)

            transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)

            # Preference order: English -> any manual -> any auto-generated.
            transcript = None
            try:
                transcript = transcript_list.find_transcript(['en'])
            except:
                try:
                    transcript = transcript_list.find_manually_created_transcript()
                except:
                    try:
                        transcript = transcript_list.find_generated_transcript()
                    except:
                        pass

            if transcript:
                transcript_data = transcript.fetch()

                # fetch() normally yields a list of {'text': ...} entries;
                # tolerate other shapes by stringifying.
                if isinstance(transcript_data, list):
                    full_text = " ".join([entry.get('text', '') if isinstance(entry, dict) else str(entry) for entry in transcript_data])
                else:
                    full_text = str(transcript_data)

                # NOTE(review): this inspects the URL (not the question text)
                # for dialogue-related words, and both branches of the `or`
                # are identical -- confirm this is intended.
                if any(phrase in url.lower() or phrase in str(url).lower()
                       for phrase in ["say", "response", "answer", "dialogue"]):
                    # Dialogue questions usually concern the end of the video.
                    return f"Transcript excerpt: ...{full_text[-500:]}"

                return f"Full transcript: {full_text[:1000]}..." if len(full_text) > 1000 else f"Full transcript: {full_text}"

        except Exception as yt_error:
            error_str = str(yt_error)
            print(f"YouTube transcript error: {yt_error}")

            # Rate limited: give up immediately with the benchmark-safe answer.
            if "429" in error_str or "Too Many Requests" in error_str:
                return "Unable to determine"

            # Secondary path: pytube metadata (title/description only).
            try:
                from pytube import YouTube
                import time

                time.sleep(1)

                yt = YouTube(url)

                title = yt.title if hasattr(yt, 'title') else "Unknown"
                description = yt.description[:200] if hasattr(yt, 'description') and yt.description else "No description"

                return f"Video info - Title: {title}\nDescription: {description}\nNote: Transcript not available"

            except Exception as pytube_error:
                print(f"Pytube error: {pytube_error}")

            return "Unable to determine"

    except Exception as e:
        return f"Error accessing YouTube video: {str(e)}"
|
|
@tool
def analyze_multimedia_reference(question: str) -> str:
    """
    Detect and provide guidance for multimedia content in questions.
    Returns specific answers for common multimedia patterns.
    """
    try:
        # Coerce non-string tool input (lists, numbers) to a plain string.
        if isinstance(question, list):
            question = " ".join(str(part) for part in question)
        elif not isinstance(question, str):
            question = str(question)

        q = question.lower()

        def mentions(*terms):
            # True when any of the given substrings occurs in the question.
            return any(term in q for term in terms)

        # Spreadsheet questions that would need the actual file.
        if mentions("excel", "spreadsheet", ".xlsx", ".xls", ".csv"):
            if mentions("total", "sum", "how much", "how many", "amount"):
                return "Cannot access spreadsheet - provide final answer: Unable to determine"
            elif "sales" in q and "total" in q:
                return "Cannot access sales data - provide final answer: Unable to determine"

        # "Attached code" questions without the actual attachment.
        if "attached" in q and ("python" in q or "code" in q):
            if "output" in q and ("numeric" in q or "final" in q):
                return "Cannot access attached code - provide final answer: Unable to determine"
            elif "fix" in q or "correct" in q:
                return "Cannot access attached code to fix - provide final answer: Unable to determine"

        # PDF counting questions.
        if ("pdf" in q or ".pdf" in q) and mentions("how many", "count", "times"):
            return "Cannot access PDF to count - provide final answer: Unable to determine"

        # Image questions (chess boards get a dedicated message).
        if mentions("image", "picture", "photo", ".png", ".jpg", ".jpeg"):
            if "chess" in q:
                return "Cannot access chess position image - provide final answer: Unable to determine"
            elif mentions("color", "what is", "describe"):
                return "Cannot access image - provide final answer: Unable to determine"

        # Audio questions.
        if mentions("audio", ".mp3", ".wav", "recording"):
            if mentions("transcribe", "what does", "study", "exam"):
                return "Cannot access audio file - provide final answer: Unable to determine"

        return "No specific multimedia pattern requiring 'Unable to determine' response"

    except Exception as e:
        return f"Error analyzing multimedia: {str(e)}"
|
|
@tool
def download_and_process_file(url: str, file_type: Optional[str] = None) -> str:
    """
    Download and process files from URLs (Excel, CSV, PDF, etc).
    Useful when questions reference files by URL.

    file_type may be 'excel', 'csv', 'pdf', or 'text'; when omitted it is
    inferred from the URL's extension.
    """
    try:
        # Coerce non-string tool input (lists, numbers) to a plain string.
        if isinstance(url, list):
            url = " ".join(str(item) for item in url)
        elif not isinstance(url, str):
            url = str(url)

        url = url.strip()

        # Infer the file type from the URL extension when not given.
        if not file_type:
            if any(ext in url.lower() for ext in ['.xlsx', '.xls']):
                file_type = 'excel'
            elif '.csv' in url.lower():
                file_type = 'csv'
            elif '.pdf' in url.lower():
                file_type = 'pdf'
            elif any(ext in url.lower() for ext in ['.txt', '.text']):
                file_type = 'text'
            else:
                return "Unable to determine file type from URL"

        import requests
        from io import BytesIO, StringIO

        # Browser-like User-Agent: some hosts reject the default requests UA.
        try:
            response = requests.get(url, timeout=15, headers={'User-Agent': 'Mozilla/5.0'})
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            return f"Failed to download file: {str(e)}"

        if file_type == 'excel':
            try:
                import pandas as pd
                df = pd.read_excel(BytesIO(response.content))

                # Summarize shape/columns plus numeric totals, since GAIA
                # spreadsheet questions usually ask for sums.
                info = []
                info.append(f"Excel file loaded successfully")
                info.append(f"Shape: {df.shape[0]} rows, {df.shape[1]} columns")
                info.append(f"Columns: {', '.join(df.columns)}")

                numeric_cols = df.select_dtypes(include=['number']).columns
                if len(numeric_cols) > 0:
                    info.append("\nNumeric column sums:")
                    for col in numeric_cols:
                        total = df[col].sum()
                        info.append(f" {col}: {total}")

                # Special-case "total sales" questions.
                if 'sales' in ' '.join(df.columns).lower():
                    sales_cols = [col for col in df.columns if 'sales' in col.lower()]
                    if sales_cols:
                        total_sales = df[sales_cols].sum().sum()
                        info.append(f"\nTotal sales: {total_sales}")

                return '\n'.join(info)

            except Exception as e:
                return f"Error processing Excel file: {str(e)}"

        elif file_type == 'csv':
            try:
                import pandas as pd
                df = pd.read_csv(StringIO(response.text))

                info = []
                info.append(f"CSV file loaded successfully")
                info.append(f"Shape: {df.shape[0]} rows, {df.shape[1]} columns")
                info.append(f"Columns: {', '.join(df.columns)}")

                # Same numeric-total summary as the Excel branch.
                numeric_cols = df.select_dtypes(include=['number']).columns
                if len(numeric_cols) > 0:
                    info.append("\nNumeric column sums:")
                    for col in numeric_cols:
                        total = df[col].sum()
                        info.append(f" {col}: {total}")

                return '\n'.join(info)

            except Exception as e:
                return f"Error processing CSV file: {str(e)}"

        elif file_type == 'pdf':
            try:
                import PyPDF2
                pdf_reader = PyPDF2.PdfReader(BytesIO(response.content))

                info = []
                info.append(f"PDF file loaded successfully")
                info.append(f"Number of pages: {len(pdf_reader.pages)}")

                # Concatenate text from every page for downstream searching.
                full_text = ""
                for page in pdf_reader.pages:
                    text = page.extract_text()
                    full_text += text + "\n"

                info.append(f"Total characters: {len(full_text)}")
                info.append(f"Total words: {len(full_text.split())}")

                info.append("\nFull text extracted and available for searching")

                return '\n'.join(info) + f"\n\nFull text (first 1000 chars):\n{full_text[:1000]}..."

            except Exception as e:
                return f"Error processing PDF file: {str(e)}"

        elif file_type == 'text':
            try:
                text_content = response.text
                info = []
                info.append(f"Text file loaded successfully")
                info.append(f"Length: {len(text_content)} characters")
                info.append(f"Lines: {len(text_content.splitlines())}")
                info.append(f"\nContent preview:\n{text_content[:500]}...")

                return '\n'.join(info)

            except Exception as e:
                return f"Error processing text file: {str(e)}"

        else:
            return f"Unsupported file type: {file_type}"

    except Exception as e:
        return f"Error downloading/processing file: {str(e)}"
|
|
@tool
def extract_file_urls(question: str) -> str:
    """
    Extract file URLs from questions for downloading.
    Returns URLs of files that can be downloaded.
    """
    try:
        # Coerce non-string tool input (lists, numbers) to a plain string.
        if isinstance(question, list):
            question = " ".join(str(part) for part in question)
        elif not isinstance(question, str):
            question = str(question)

        import re

        # Absolute http(s) URLs ending in a known document extension.
        pattern = r'https?://[^\s<>"{}|\\^`\[\]]+\.(?:xlsx|xls|csv|pdf|txt|doc|docx)'
        found = re.findall(pattern, question, re.IGNORECASE)

        if not found:
            return "No downloadable file URLs found in the question"
        return f"Found downloadable file URLs: {', '.join(found)}"

    except Exception as e:
        return f"Error extracting URLs: {str(e)}"
|
|
@tool
def get_current_datetime() -> str:
    """Get the current date and time as "YYYY-MM-DD HH:MM:SS TZ".

    The naive ``datetime.now()`` carries no tzinfo, so ``%Z`` previously
    rendered as an empty string (leaving a dangling trailing space);
    ``astimezone()`` attaches the local timezone so the abbreviation shows.
    """
    return datetime.now().astimezone().strftime("%Y-%m-%d %H:%M:%S %Z")
|
|
| |
| class LangGraphAgent: |
    def __init__(self, anthropic_api_key: Optional[str] = None):
        """Build the tool-calling agent: LLM, tool belt, and compiled graph.

        Args:
            anthropic_api_key: Explicit API key; falls back to the
                ANTHROPIC_API_KEY environment variable.

        Raises:
            ValueError: If no API key is available from either source.
        """
        api_key = anthropic_api_key or os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY must be provided or set in environment variables")

        # Moderate temperature: some reasoning flexibility without drifting
        # too far from factual answers.
        self.llm = ChatAnthropic(
            api_key=api_key,
            model="claude-3-5-sonnet-20241022",
            temperature=0.3,
            max_tokens=4096
        )

        # Every @tool-decorated capability the agent may invoke.
        self.tools = [
            web_search,
            calculator,
            python_executor,
            extract_image_from_question,
            analyze_attachments,
            analyze_reversed_text,
            analyze_code_in_question,
            get_youtube_transcript,
            analyze_multimedia_reference,
            extract_file_urls,
            download_and_process_file,
            get_current_datetime
        ]

        # LLM handle that can emit structured tool calls.
        self.llm_with_tools = self.llm.bind_tools(self.tools)

        # Prebuilt LangGraph node that executes requested tool calls.
        self.tool_node = ToolNode(self.tools)

        # Compiled agent <-> tools loop (see _build_graph).
        self.graph = self._build_graph()
| |
| def _build_graph(self): |
| workflow = StateGraph(AgentState) |
| |
| |
| workflow.add_node("agent", self._call_model) |
| workflow.add_node("tools", self.tool_node) |
| |
| |
| workflow.set_entry_point("agent") |
| |
| |
| workflow.add_conditional_edges( |
| "agent", |
| self._should_continue, |
| { |
| "continue": "tools", |
| "end": END |
| } |
| ) |
| |
| |
| workflow.add_edge("tools", "agent") |
| |
| return workflow.compile() |
| |
| def _call_model(self, state: AgentState): |
| """Call the model with tools.""" |
| messages = state["messages"] |
| response = self.llm_with_tools.invoke(messages) |
| return {"messages": [response]} |
| |
    def _should_continue(self, state: AgentState) -> str:
        """Routing predicate: "continue" to the tools node or "end" the run."""
        last_message = state["messages"][-1]

        # The model requested tool calls: always execute them.
        if hasattr(last_message, "tool_calls") and last_message.tool_calls:
            return "continue"

        # Count every tool call issued so far in this run.
        tool_call_count = 0
        for msg in state["messages"]:
            if hasattr(msg, "tool_calls") and msg.tool_calls:
                tool_call_count += len(msg.tool_calls)

        # Early in the run, push the model to keep researching unless it has
        # already produced a FINAL ANSWER.
        if tool_call_count < 2:
            if hasattr(last_message, "content") and last_message.content:
                # Content may be a list of blocks; normalize to str for search.
                content_str = last_message.content if isinstance(last_message.content, str) else str(last_message.content)
                has_final_answer = "FINAL ANSWER:" in content_str

                # NOTE(review): this returns "continue" even though the last
                # message carries no tool_calls (checked above), which routes
                # to the ToolNode with nothing to execute -- confirm ToolNode
                # tolerates that; otherwise this branch should return "end".
                if not has_final_answer and tool_call_count < 3:
                    return "continue"

        # Hard stop after heavy tool usage or once a final answer appears.
        # (The condition is redundant with the unconditional "end" below; it
        # documents the intended stopping criteria.)
        content_str = str(last_message.content) if hasattr(last_message, "content") else ""
        if tool_call_count >= 6 or "FINAL ANSWER:" in content_str:
            return "end"

        return "end"
| |
    def run(self, question: str) -> str:
        """Run the agent on a question and return the cleaned final answer.

        Invokes the compiled graph with a GAIA-specific system prompt, then
        extracts the text after "FINAL ANSWER:" from the message history.
        Any exception is returned as an "Error: ..." string rather than
        raised, so the evaluation loop keeps going.
        """
        # Debug trace: GAIA payloads occasionally arrive as non-string types.
        print(f"\nDEBUG LangGraphAgent.run():")
        print(f" Input type: {type(question)}")
        print(f" Input value: {repr(question)[:200]}...")

        system_prompt = """You are solving GAIA benchmark questions that require deep research and analysis.

IMPORTANT: You should:
1. Use multiple tools to thoroughly research the question
2. Search for specific facts, verify information, and perform calculations
3. Think step-by-step and use chain-of-thought reasoning
4. Double-check facts with multiple searches if needed
5. Use python_executor for complex data analysis or calculations

At the very end, after all your research and reasoning, provide ONLY the final answer in this format:
FINAL ANSWER: [your answer here]

The final answer should contain ONLY the requested information:
- Numbers: just the number (e.g., "5" not "5 people")
- Years: just the year (e.g., "1969")
- Names: exact name with proper capitalization
- Yes/No: exactly "Yes" or "No"
- Lists: comma-separated values

Available tools:
- web_search: Search for current information (use multiple times with different queries)
- calculator: Perform calculations and unit conversions
- python_executor: Complex analysis, data processing, date calculations
- analyze_attachments: Detect references to external files/media
- analyze_reversed_text: Decode backwards or puzzle text
- analyze_code_in_question: Extract and analyze Python code from questions
- get_youtube_transcript: Extract transcripts from YouTube videos
- analyze_multimedia_reference: Handle questions about images, audio, PDFs, Excel files
- extract_file_urls: Find downloadable file URLs in questions
- download_and_process_file: Download and analyze files from URLs (Excel, CSV, PDF)
- get_current_datetime: Get current date/time

For questions mentioning "attached code" or containing code snippets:
1. First use analyze_code_in_question to extract the code
2. Then use python_executor to run it and get the output

For questions with YouTube videos:
1. Use get_youtube_transcript to extract the video transcript
2. Search the transcript for the relevant information

For questions mentioning files with URLs:
1. Use extract_file_urls to find any file URLs in the question
2. If URLs are found, use download_and_process_file to download and analyze the file
3. Extract the specific information requested (totals, counts, etc.)
4. For Excel files asking for totals, sum the relevant columns
5. For PDFs asking for word counts, search the extracted text

For questions mentioning attached files without URLs:
1. Use analyze_multimedia_reference to check if file access is needed
2. Return "Unable to determine" if the file cannot be accessed"""

        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=question)
        ]

        try:
            # recursion_limit bounds the agent/tools loop; thread_id would be
            # used by a checkpointer if one were configured on the graph.
            config = {
                "recursion_limit": 25,
                "configurable": {
                    "thread_id": "gaia_evaluation"
                }
            }

            result = self.graph.invoke({"messages": messages}, config)

            # Reduce the full message history to a single answer string.
            final_answer = self._extract_final_answer(result["messages"])
            return final_answer

        except Exception as e:
            return f"Error: {str(e)}"
| |
| def _extract_final_answer(self, messages: List[BaseMessage]) -> str: |
| """Extract the final answer from the message history.""" |
| |
| for message in reversed(messages): |
| if hasattr(message, "content") and message.content: |
| content = message.content.strip() |
| |
| |
| if "FINAL ANSWER:" in content: |
| parts = content.split("FINAL ANSWER:") |
| if len(parts) >= 2: |
| answer = parts[-1].strip() |
| |
| answer = self._clean_answer(answer) |
| return answer |
| |
| |
| if isinstance(message, AIMessage): |
| return self._clean_answer(content) |
| |
| return "Unable to determine" |
| |
| def _clean_answer(self, answer: str) -> str: |
| """Clean and format the final answer.""" |
| |
| if isinstance(answer, list): |
| answer = " ".join(str(item) for item in answer) |
| elif not isinstance(answer, str): |
| answer = str(answer) |
| |
| answer = answer.strip() |
| |
| |
| if len(answer) > 2 and answer[0] == '"' and answer[-1] == '"': |
| answer = answer[1:-1] |
| |
| |
| prefixes_to_remove = [ |
| "the answer is", "answer:", "based on", "according to", |
| "my research shows", "i found that", "the result is", |
| "after searching", "from the", "it is", "it's", "there are", |
| "there is", "approximately", "about", "around" |
| ] |
| |
| lower_answer = answer.lower() |
| for prefix in prefixes_to_remove: |
| if lower_answer.startswith(prefix): |
| answer = answer[len(prefix):].strip() |
| if answer and answer[0] == ':': |
| answer = answer[1:].strip() |
| lower_answer = answer.lower() |
| |
| |
| if "unable to" in lower_answer or "cannot" in lower_answer: |
| return "Unable to determine" |
| |
| |
| if lower_answer in ["yes.", "no.", "yes,", "no,"]: |
| return answer[:-1] |
| |
| |
| if answer.endswith(".") and " " not in answer: |
| answer = answer[:-1] |
| |
| return answer |
|
|
| |
class BasicAgent:
    """Thin wrapper around LangGraphAgent handling lazy API-key initialization."""

    def __init__(self):
        print("Initializing LangGraph Agent...")
        key = os.getenv("ANTHROPIC_API_KEY")
        if key:
            try:
                self.agent = LangGraphAgent(key)
                print("LangGraph Agent initialized successfully.")
            except Exception as e:
                print(f"Error initializing LangGraph Agent: {e}")
                self.agent = None
        else:
            print("Warning: ANTHROPIC_API_KEY not found in environment variables.")
            print("Please set it in the Gradio interface or as an environment variable.")
            self.agent = None

    def set_api_key(self, api_key: str):
        """Set or update the API key."""
        if not api_key:
            return False
        try:
            self.agent = LangGraphAgent(api_key)
        except Exception as e:
            print(f"Error setting API key: {e}")
            return False
        return True

    def __call__(self, question: str) -> str:
        divider = "=" * 60
        print(f"\n{divider}")
        print("DEBUG: Agent received question")
        print(f"Question type: {type(question)}")
        q_len = len(question) if isinstance(question, str) else "N/A"
        print(f"Question length: {q_len}")
        print(f"Question preview: {str(question)[:200]}...")
        print(f"{divider}\n")

        if not self.agent:
            return "Error: Agent not initialized. Please set your ANTHROPIC_API_KEY."

        try:
            answer = self.agent.run(question)
        except Exception as e:
            import traceback
            print("\nDEBUG: Error occurred!")
            print(f"Error type: {type(e)}")
            print(f"Error details: {str(e)}")
            print(f"Traceback:\n{traceback.format_exc()}")
            return f"Error processing question: {str(e)}"

        print("\nDEBUG: Agent generated answer")
        print(f"Answer type: {type(answer)}")
        print(f"Answer preview: {str(answer)[:200]}...")
        return answer
|
|
| |
# Lazily-created singleton agent shared by the Gradio callbacks below;
# created/re-keyed by initialize_agent_with_key() and run_and_submit_all().
global_agent: Optional[BasicAgent] = None
|
|
def validate_api_keys(anthropic_key: str, serpapi_key: str = None, tavily_key: str = None):
    """Validate the provided API keys and build a human-readable report.

    Each supplied key is exercised with a minimal live request; missing
    optional keys are reported as informational lines rather than errors.

    Args:
        anthropic_key: Anthropic API key (required for the agent).
        serpapi_key: Optional SerpAPI key used as a search backup.
        tavily_key: Optional Tavily key for primary web search.

    Returns:
        A newline-joined status report, one line per key.
    """
    results = []

    # --- Anthropic: issue a tiny completion to prove the key works. ---
    if anthropic_key:
        try:
            test_llm = ChatAnthropic(
                api_key=anthropic_key,
                model="claude-3-5-sonnet-20241022",
                max_tokens=10
            )
            test_llm.invoke([HumanMessage(content="test")])
            results.append("✅ Anthropic API key is valid")
        except Exception as e:
            error_msg = str(e)
            if "401" in error_msg or "authentication" in error_msg.lower():
                results.append("❌ Anthropic API key is invalid or expired")
            else:
                results.append(f"❌ Anthropic API error: {error_msg[:100]}...")
    else:
        results.append("❌ No Anthropic API key provided")

    # --- Tavily: one-result search via the REST endpoint. ---
    if tavily_key:
        try:
            import requests
            test_url = "https://api.tavily.com/search"
            test_data = {
                "api_key": tavily_key,
                "query": "test",
                "max_results": 1
            }
            response = requests.post(test_url, json=test_data, timeout=5)
            if response.status_code == 200:
                results.append("✅ Tavily API key is valid")
            else:
                results.append(f"❌ Tavily API key error: {response.status_code}")
        except Exception as e:
            results.append(f"⚠️ Tavily API test error: {str(e)[:100]}...")
    else:
        results.append("ℹ️ No Tavily API key provided")

    # --- SerpAPI: minimal Google search. ---
    if serpapi_key:
        try:
            params = {
                "q": "test",
                "api_key": serpapi_key,
                "num": 1,
                "engine": "google"
            }
            search = GoogleSearch(params)
            result = search.get_dict()
            # The SerpAPI client reports failures via an "error" key in the
            # returned dict instead of raising, so an invalid key must be
            # detected explicitly rather than assumed valid.
            if isinstance(result, dict) and result.get("error"):
                results.append(f"⚠️ SerpAPI key error: {str(result['error'])[:100]}...")
            else:
                results.append("✅ SerpAPI key is valid")
        except Exception as e:
            results.append(f"⚠️ SerpAPI key error: {str(e)[:100]}...")
    else:
        results.append("ℹ️ No SerpAPI key provided")

    return "\n".join(results)


def initialize_agent_with_key(api_key: str):
    """Initialize (or re-key) the module-level agent with the given API key.

    Args:
        api_key: Anthropic API key.

    Returns:
        A status string combining key validation and initialization results.
    """
    global global_agent

    validation_result = validate_api_keys(api_key)
    # Match on the emoji-free substring so the check is robust to the exact
    # status-icon prefix emitted by validate_api_keys().
    if "Anthropic API key is invalid" in validation_result:
        return validation_result

    if api_key:
        if global_agent is None:
            global_agent = BasicAgent()
        success = global_agent.set_api_key(api_key)
        if success:
            return f"{validation_result}\n\n✅ Agent initialized successfully!"
        else:
            return "❌ Failed to initialize agent. Please check if your API key is valid."
    return "❌ Please provide an API key."
|
|
def run_and_submit_all(api_key: str, profile: gr.OAuthProfile | None):
    """
    Fetches all questions, runs the BasicAgent on them, submits all answers,
    and displays the results.

    Args:
        api_key: Anthropic API key; when provided the agent is (re)initialized.
        profile: Hugging Face OAuth profile of the logged-in user, or None.

    Returns:
        Tuple of (status message, pandas DataFrame of per-question results);
        the DataFrame is None on early failures.
    """
    global global_agent

    # (Re)initialize whenever a key is supplied so a changed key takes effect.
    if global_agent is None or api_key:
        init_msg = initialize_agent_with_key(api_key)
        print(init_msg)
        if "Failed" in init_msg or "Please provide" in init_msg:
            return init_msg, None

    # SPACE_ID identifies the hosting HF Space (used to link the agent code).
    space_id = os.getenv("SPACE_ID")

    if profile:
        username = f"{profile.username}"
        print(f"User logged in: {username}")
    else:
        print("User not logged in.")
        return "Please Login to Hugging Face with the button.", None

    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"

    agent = global_agent
    if not agent:
        return "Error: Agent not initialized properly.", None

    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main" if space_id else "local"
    print(f"Agent code URL: {agent_code}")

    # --- Fetch the evaluation questions. ---
    print(f"Fetching questions from: {questions_url}")
    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
        if not questions_data:
            print("Fetched questions list is empty.")
            return "Fetched questions list is empty or invalid format.", None
        print(f"Fetched {len(questions_data)} questions.")
    except Exception as e:
        print(f"Error fetching questions: {e}")
        return f"Error fetching questions: {e}", None

    # --- Run the agent on every question, collecting answers and a log. ---
    results_log = []
    answers_payload = []
    print(f"Running agent on {len(questions_data)} questions...")

    for i, item in enumerate(questions_data, 1):
        task_id = item.get("task_id")
        question_text = item.get("question")

        if not task_id or question_text is None:
            print(f"Skipping item with missing task_id or question: {item}")
            continue

        print(f"\nProcessing question {i}/{len(questions_data)}: {task_id}")

        try:
            submitted_answer = agent(question_text)
            answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
            results_log.append({
                "Task ID": task_id,
                "Question": question_text[:100] + "..." if len(question_text) > 100 else question_text,
                "Submitted Answer": submitted_answer[:200] + "..." if len(submitted_answer) > 200 else submitted_answer
            })
        except Exception as e:
            # Record the failure as the submitted answer so the task still
            # appears in both the submission payload and the results table.
            print(f"Error running agent on task {task_id}: {e}")
            error_answer = f"AGENT ERROR: {e}"
            answers_payload.append({"task_id": task_id, "submitted_answer": error_answer})
            results_log.append({
                "Task ID": task_id,
                # Truncate the same way as the success path (previously this
                # unconditionally appended "..." even to short questions).
                "Question": question_text[:100] + "..." if len(question_text) > 100 else question_text,
                "Submitted Answer": error_answer
            })

    if not answers_payload:
        print("Agent did not produce any answers to submit.")
        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)

    # --- Submit all answers in a single POST. ---
    submission_data = {
        "username": username.strip(),
        "agent_code": agent_code,
        "answers": answers_payload
    }
    status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
    print(status_update)

    print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
    try:
        response = requests.post(submit_url, json=submission_data, timeout=60)
        response.raise_for_status()
        result_data = response.json()
        final_status = (
            f"Submission Successful!\n"
            f"User: {result_data.get('username')}\n"
            f"Overall Score: {result_data.get('score', 'N/A')}% "
            f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
            f"Message: {result_data.get('message', 'No message received.')}"
        )
        print("Submission successful.")
        results_df = pd.DataFrame(results_log)
        return final_status, results_df
    except Exception as e:
        status_message = f"Submission Failed: {str(e)}"
        print(status_message)
        results_df = pd.DataFrame(results_log)
        return status_message, results_df
|
|
| |
# Gradio UI: key entry, validation, agent initialization, and the full
# evaluation/submission run. Status-icon and arrow characters below restore
# the intended emoji that had been corrupted by a UTF-8/cp874 mis-decode.
with gr.Blocks() as demo:
    gr.Markdown("# LangGraph Agent for GAIA Evaluation")
    gr.Markdown(
        """
        **This agent uses LangGraph with multiple tools to answer complex questions:**
        - 🔍 Web Search (Tavily → DuckDuckGo → SerpAPI)
        - 🧮 Calculator for mathematical computations
        - 🐍 Python code execution
        - 📅 Current date/time
        - 🖼️ Image analysis (description-based)

        **Instructions:**
        1. Enter your Anthropic API key (Claude Sonnet 3.5)
        2. Optionally enter your Tavily API key for best web search (free tier: 1000/month)
        3. Optionally enter your SerpAPI key as backup
        4. Log in to your Hugging Face account
        5. Click 'Run Evaluation & Submit All Answers'

        **Search Priority:** Tavily (if key) → DuckDuckGo (free) → SerpAPI (if key)
        """
    )

    with gr.Row():
        with gr.Column():
            gr.LoginButton()

    with gr.Row():
        with gr.Column():
            api_key_input = gr.Textbox(
                label="Anthropic API Key (Required)",
                placeholder="sk-ant-...",
                type="password"
            )
            tavily_key_input = gr.Textbox(
                label="Tavily API Key (Recommended for web search)",
                placeholder="tvly-...",
                type="password"
            )
            serpapi_key_input = gr.Textbox(
                label="SerpAPI Key (Optional backup)",
                placeholder="Your SerpAPI key...",
                type="password"
            )

    with gr.Row():
        validate_button = gr.Button("Validate API Keys", variant="secondary")
        init_button = gr.Button("Initialize Agent", variant="secondary")
        run_button = gr.Button("Run Evaluation & Submit All Answers", variant="primary")

    status_output = gr.Textbox(label="Status / Results", lines=8, interactive=False)
    results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)

    def set_tavily_key(key):
        """Export the Tavily key to the environment for the search tool."""
        if key:
            os.environ["TAVILY_API_KEY"] = key
            return "✅ Tavily API key set!"
        return ""

    def set_serpapi_key(key):
        """Export the SerpAPI key to the environment as a search backup."""
        if key:
            os.environ["SERPAPI_KEY"] = key
            return "✅ SerpAPI key set!"
        return ""

    # Keys are exported to the environment as soon as they are typed.
    tavily_key_input.change(set_tavily_key, inputs=[tavily_key_input], outputs=[])
    serpapi_key_input.change(set_serpapi_key, inputs=[serpapi_key_input], outputs=[])

    def validate_all_keys(anthropic_key, tavily_key, serpapi_key):
        """Export optional keys, then validate all three and return the report."""
        if tavily_key:
            os.environ["TAVILY_API_KEY"] = tavily_key
        if serpapi_key:
            os.environ["SERPAPI_KEY"] = serpapi_key
        return validate_api_keys(anthropic_key, serpapi_key, tavily_key)

    validate_button.click(
        fn=validate_all_keys,
        inputs=[api_key_input, tavily_key_input, serpapi_key_input],
        outputs=[status_output]
    )

    init_button.click(
        fn=initialize_agent_with_key,
        inputs=[api_key_input],
        outputs=[status_output]
    )

    run_button.click(
        fn=run_and_submit_all,
        inputs=[api_key_input],
        outputs=[status_output, results_table]
    )
|
|
if __name__ == "__main__":
    # Startup banner summarizing which API keys the app can make use of.
    edge = "-" * 30
    print(f"\n{edge} App Starting {edge}")
    print("LangGraph Agent for GAIA Evaluation")
    print("Required: ANTHROPIC_API_KEY")
    print("Recommended: TAVILY_API_KEY for best web search (1000 free/month)")
    print("Optional: SERPAPI_KEY as backup")
    print("Fallback: DuckDuckGo search (no API key needed)")
    print("-" * 74 + "\n")

    demo.launch(debug=True, share=False)