import logging

from langchain.tools import BaseTool
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage

from src.state.state import State
from src.tools.websearch import WebSearchTool
|
|
class WebSearchChatbot:
    """Produce chat replies from a message state, optionally using web search.

    When a non-blank Tavily API key is given and the search tool can be built,
    the tool is bound to the model. Tool calls the model emits are executed and
    the model is invoked again with their results. Otherwise the plain model
    is used, and latest messages that look like search requests receive a
    canned disclaimer instead of a model call.
    """

    # Lower-case substrings that, in the latest message, are treated as a
    # request for live information (checked only when search is unavailable).
    _SEARCH_KEYWORDS = ('search', 'find', 'latest', 'current', 'news')

    _logger = logging.getLogger(__name__)

    def __init__(self, model, session_id: str = "default", tavily_api_key: str = None):
        """Configure the chatbot.

        Args:
            model: Chat model exposing ``invoke`` and, for search,
                ``bind_tools``.
            session_id: Value placed in the invoke config under
                ``configurable.session_id``.
            tavily_api_key: Tavily key. Web search is disabled when it is
                missing or blank, or when the search tool cannot be set up.
        """
        self.model = model
        self.session_id = session_id
        self.memory_config = {"configurable": {"session_id": session_id}}

        # "No search" defaults. They are replaced only after the whole setup
        # succeeds, so the attributes are always defined and never half-set.
        self.web_search = None
        self.tools = []
        self.model_with_tools = model
        self.has_search = False

        if not (tavily_api_key and tavily_api_key.strip()):
            return

        try:
            web_search = WebSearchTool(tavily_api_key)
            tools = [web_search.get_tool()]
            model_with_tools = model.bind_tools(tools)
        except Exception:
            # Search is optional: fall back to the plain model, but leave a
            # trace of why instead of swallowing the error silently.
            self._logger.warning(
                "Web search disabled: could not initialise the search tool",
                exc_info=True,
            )
            return

        self.web_search = web_search
        self.tools = tools
        self.model_with_tools = model_with_tools
        self.has_search = True

    def process(self, state):
        """Generate the assistant's reply for the conversation in ``state``.

        Args:
            state: Mapping whose ``'messages'`` entry is the conversation so far.

        Returns:
            ``state`` itself when there are no messages. Otherwise a dict whose
            ``'messages'`` entry holds the new message(s) to add:
            - a single reply, when no tool calls were made;
            - when tools ran, a list of the tool-calling AI message, each
              ``ToolMessage`` and the final reply, in order.

            NOTE(review): returning a list assumes ``State['messages']`` uses
            an appending reducer such as ``add_messages``. Confirm against
            ``src.state.state``.
        """
        messages = state['messages']
        if not messages:
            return state

        if not self.has_search and self._looks_like_search_request(messages[-1]):
            search_disclaimer = (
                "I don't have web search capabilities enabled. "
                "Please provide a Tavily API key to search for current information."
            )
            response_content = f"{search_disclaimer}\n\nBased on my training data, I can still help with general questions."
            return {'messages': AIMessage(content=response_content)}

        response = self.model_with_tools.invoke(messages, config=self.memory_config)
        tool_calls = getattr(response, 'tool_calls', None)
        if not tool_calls:
            return {'messages': response}

        tool_messages = [
            ToolMessage(
                content=str(self._execute_tool_call(tool_call)),
                tool_call_id=tool_call['id'],
            )
            for tool_call in tool_calls
        ]
        # Build a new list rather than appending to state['messages']: in-place
        # mutation of graph state bypasses the reducer and checkpointing.
        conversation = [*messages, response, *tool_messages]
        # NOTE(review): only one tool round is executed. If this final reply
        # itself contains tool_calls, they are returned unexecuted.
        final_response = self.model_with_tools.invoke(conversation, config=self.memory_config)
        return {'messages': [response, *tool_messages, final_response]}

    def _looks_like_search_request(self, message):
        """Return True if ``message`` has string content containing a search keyword.

        Non-string content (e.g. a list of multimodal parts) or a missing
        ``content`` attribute yields False rather than raising.
        """
        content = getattr(message, 'content', None)
        if not isinstance(content, str):
            return False
        lowered = content.lower()
        return any(keyword in lowered for keyword in self._SEARCH_KEYWORDS)

    def _execute_tool_call(self, tool_call):
        """Run one model-requested tool call and return the tool's raw output.

        The tool is looked up by name among the tools bound to the model in
        ``__init__``, so dispatch matches exactly what the model was offered.

        Returns:
            The tool's ``invoke`` result, ``"Tool not found"`` for an unknown
            name, or an error string if the tool raised. Returning an error
            string lets the model still answer instead of the turn crashing.
        """
        tool = next((t for t in self.tools if t.name == tool_call['name']), None)
        if tool is None:
            return "Tool not found"
        try:
            return tool.invoke(tool_call['args'])
        except Exception as exc:
            self._logger.exception("Tool %r failed", tool_call['name'])
            return f"Tool error: {exc}"
|
|