Spaces:
Sleeping
Sleeping
Commit ·
348c1c6
1
Parent(s): b1fdd32
improve rag_search
Browse files
agent.py
CHANGED
|
@@ -21,7 +21,7 @@ from langchain.agents import initialize_agent, AgentType
|
|
| 21 |
from langchain_community.retrievers import BM25Retriever
|
| 22 |
from langchain.schema import BaseMessage, SystemMessage, HumanMessage
|
| 23 |
from langgraph.graph.message import add_messages
|
| 24 |
-
from langgraph.graph import START, StateGraph
|
| 25 |
from langgraph.prebuilt import ToolNode, tools_condition
|
| 26 |
from langchain_core.documents import Document
|
| 27 |
|
|
@@ -33,6 +33,7 @@ from dotenv import load_dotenv
|
|
| 33 |
from contextlib import redirect_stdout
|
| 34 |
from langchain_community.tools import TavilySearchResults
|
| 35 |
from tavily import TavilyClient
|
|
|
|
| 36 |
|
| 37 |
# Load environment variables from .env file
|
| 38 |
# in HF Spaces, the .env file is saved in Variables and secrets in settings
|
|
@@ -274,12 +275,29 @@ def download_file(url_or_path: str, save_dir: str = "./downloads") -> str:
|
|
| 274 |
return f"Error downloading/copying file: {e}"
|
| 275 |
|
| 276 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
# Update tools list
|
| 278 |
tools: List[StructuredTool] = [
|
| 279 |
-
calculate,
|
| 280 |
read_pdf, read_csv, read_spreadsheet, transcribe_audio,
|
| 281 |
youtube_transcript_tool, youtube_transcript_api, read_jsonl,
|
| 282 |
-
python_interpreter, download_file,
|
| 283 |
]
|
| 284 |
|
| 285 |
class AgentState(TypedDict):
|
|
@@ -352,6 +370,15 @@ class MyAgent:
|
|
| 352 |
except Exception as e:
|
| 353 |
print(f"Error loading {path}: {e}")
|
| 354 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
|
| 356 |
def build_retriever(self):
|
| 357 |
"""
|
|
@@ -368,18 +395,13 @@ class MyAgent:
|
|
| 368 |
@tool(name="rag_search")
|
| 369 |
def rag_search(query: str) -> str:
|
| 370 |
"""Search loaded documents for relevant information."""
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
for i, doc in enumerate(res[:3]))
|
| 379 |
-
return "No relevant information found in loaded documents."
|
| 380 |
-
except Exception as e:
|
| 381 |
-
return f"Error searching documents: {e}"
|
| 382 |
-
|
| 383 |
# Remove existing rag_search if present to prevent duplicates
|
| 384 |
self.tools = [t for t in self.tools if t.name != "rag_search"]
|
| 385 |
self.tools.append(rag_search)
|
|
@@ -426,10 +448,17 @@ class MyAgent:
|
|
| 426 |
builder.add_edge(START, "assistant")
|
| 427 |
|
| 428 |
# Fix conditional edges with better check
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
builder.add_conditional_edges(
|
| 430 |
"assistant",
|
| 431 |
-
|
| 432 |
-
"tools"
|
| 433 |
)
|
| 434 |
builder.add_edge("tools", "assistant")
|
| 435 |
|
|
@@ -441,8 +470,10 @@ class MyAgent:
|
|
| 441 |
last_message = out["messages"][-1].content
|
| 442 |
|
| 443 |
# Extract only the FINAL ANSWER part
|
| 444 |
-
|
| 445 |
-
|
|
|
|
|
|
|
| 446 |
return last_message.strip()
|
| 447 |
except Exception as e:
|
| 448 |
return f"Error processing question: {e}"
|
|
|
|
| 21 |
from langchain_community.retrievers import BM25Retriever
|
| 22 |
from langchain.schema import BaseMessage, SystemMessage, HumanMessage
|
| 23 |
from langgraph.graph.message import add_messages
|
| 24 |
+
from langgraph.graph import START, END, StateGraph
|
| 25 |
from langgraph.prebuilt import ToolNode, tools_condition
|
| 26 |
from langchain_core.documents import Document
|
| 27 |
|
|
|
|
| 33 |
from contextlib import redirect_stdout
|
| 34 |
from langchain_community.tools import TavilySearchResults
|
| 35 |
from tavily import TavilyClient
|
| 36 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 37 |
|
| 38 |
# Load environment variables from .env file
|
| 39 |
# in HF Spaces, the .env file is saved in Variables and secrets in settings
|
|
|
|
| 275 |
return f"Error downloading/copying file: {e}"
|
| 276 |
|
| 277 |
|
| 278 |
+
@tool
def extract_table(file_path: str, query: str = "") -> str:
    """Extract relevant rows from a CSV or Excel file.

    Args:
        file_path: Path to a ``.csv``, ``.xlsx`` or ``.xls`` file.
        query: Optional literal text filter. When non-empty, only rows where
            any column contains it (case-insensitive) are kept.

    Returns:
        Up to the first 10 matching rows rendered as CSV text, or an
        explanatory error message (matching the error-string convention of
        the other file tools, e.g. download_file).
    """
    import pandas as pd
    ext = Path(file_path).suffix.lower()
    try:
        if ext == ".csv":
            df = pd.read_csv(file_path)
        elif ext in (".xlsx", ".xls"):
            df = pd.read_excel(file_path)
        else:
            return "Unsupported file type."
    except Exception as e:
        # Return an error string instead of raising, like the sibling tools do.
        return f"Error reading table: {e}"
    if query:
        # regex=False: treat the query literally — "1.5" or "(a)" must not be
        # parsed as a regular expression (which could raise re.error).
        # na=False: NaN cells must not poison the row mask.
        mask = df.apply(
            lambda row: row.astype(str)
            .str.contains(query, case=False, regex=False, na=False)
            .any(),
            axis=1,
        )
        df = df[mask]
    return df.head(10).to_csv(index=False)
|
| 294 |
+
|
| 295 |
# Registry of tools exposed to the agent. (rag_search is registered
# separately by build_retriever, which appends it to self.tools.)
tools: List[StructuredTool] = [
    calculate,
    tavily_search,
    wikipedia_search,
    image_recognition,
    read_pdf,
    read_csv,
    read_spreadsheet,
    transcribe_audio,
    youtube_transcript_tool,
    youtube_transcript_api,
    read_jsonl,
    python_interpreter,
    download_file,
    extract_table,
]
|
| 302 |
|
| 303 |
class AgentState(TypedDict):
|
|
|
|
| 370 |
except Exception as e:
|
| 371 |
print(f"Error loading {path}: {e}")
|
| 372 |
continue
|
| 373 |
+
# Split every loaded document into overlapping chunks so retrieval matches
# passage-sized text instead of whole files.
# NOTE(review): `path` below is the loop variable from the loading loop
# above, so every chunk's "source" metadata gets the most recently loaded
# path — confirm this is intended rather than the per-document path.
splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)
for source_doc in loaded_docs:
    pieces = splitter.split_text(source_doc.page_content)
    for chunk_idx, piece in enumerate(pieces):
        chunk_doc = Document(
            page_content=piece,
            metadata={**source_doc.metadata, "chunk": chunk_idx, "source": path},
        )
        self.docs.append(chunk_doc)
|
| 382 |
|
| 383 |
def build_retriever(self):
|
| 384 |
"""
|
|
|
|
| 395 |
# The tool decorator takes the tool name positionally, not as a `name=`
# keyword (which raises TypeError in langchain_core's `tool`).
@tool("rag_search")
def rag_search(query: str) -> str:
    """Search the loaded documents for passages relevant to *query*.

    Returns up to 3 matching chunks, each prefixed with its source path,
    or an explanatory message when no documents are loaded, nothing
    matches, or the retriever fails.
    """
    if not self.retriever:
        return "No documents loaded."
    try:
        docs = self.retriever.get_relevant_documents(query)
    except Exception as e:
        # Restore the error guard the previous revision had: retriever
        # backends can raise on unusual queries; return a string like the
        # other tools instead of crashing the graph.
        return f"Error searching documents: {e}"
    if not docs:
        return "No relevant information found."
    return "\n\n".join(
        f"{doc.metadata.get('source', '')}: {doc.page_content[:500]}"
        for doc in docs[:3]
    )

# Drop any stale rag_search before appending so the tool list never holds
# duplicates across repeated build_retriever calls.
self.tools = [t for t in self.tools if t.name != "rag_search"]
self.tools.append(rag_search)
|
|
|
|
| 448 |
builder.add_edge(START, "assistant")
|
| 449 |
|
| 450 |
# Route after the assistant node based on whether the LLM actually emitted
# tool calls. Routing on state["input_file"] alone (as before) loops the
# graph assistant -> tools -> assistant forever, because input_file never
# becomes empty once set.
def _route_from_assistant(state):
    # Inspect the most recent message; only go to the tools node when the
    # model requested a tool invocation, otherwise the turn is finished.
    last = state["messages"][-1]
    if getattr(last, "tool_calls", None):
        return "tools"
    return "assistant"

builder.add_conditional_edges(
    "assistant",
    _route_from_assistant,
    {"tools": "tools", "assistant": END},
)
builder.add_edge("tools", "assistant")
|
| 464 |
|
|
|
|
| 470 |
last_message = out["messages"][-1].content
|
| 471 |
|
| 472 |
# Extract only the FINAL ANSWER part
|
| 473 |
+
import re
|
| 474 |
+
match = re.search(r"FINAL ANSWER[:\s]*([^\n]*)", last_message, re.IGNORECASE)
|
| 475 |
+
if match:
|
| 476 |
+
return match.group(1).strip()
|
| 477 |
return last_message.strip()
|
| 478 |
except Exception as e:
|
| 479 |
return f"Error processing question: {e}"
|