Spaces:
Configuration error
Configuration error
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import faiss | |
| import re | |
| import ast | |
| import os | |
| import urllib.request | |
| from sentence_transformers import SentenceTransformer | |
| from sentence_transformers.util import cos_sim | |
| from langchain.chat_models import ChatOpenAI | |
| from langchain.agents import initialize_agent, AgentType, tool | |
| from streamlit_chat import message | |
| # --------------------------- | |
| # Configuration | |
| # --------------------------- | |
# Page-level settings: browser-tab title and wide layout.
# NOTE(review): Streamlit documents that set_page_config should run before
# other Streamlit commands on the page -- keep it ahead of the UI code below
# and confirm the exact rule for the installed Streamlit version.
st.set_page_config(page_title="📱 AI Product Search Agent", layout="wide")
| # --------------------------- | |
| # Load model | |
| # --------------------------- | |
@st.cache_resource
def load_model():
    """Load the sentence-embedding model used to encode queries and titles.

    Streamlit re-executes this script on every user interaction, so the
    loaded model is cached with ``st.cache_resource`` and reused across
    reruns instead of being rebuilt for every chat message.

    Returns:
        SentenceTransformer: The ``all-MiniLM-L6-v2`` model.
    """
    return SentenceTransformer("all-MiniLM-L6-v2")
| # --------------------------- | |
| # Load dataset and FAISS index | |
| # --------------------------- | |
@st.cache_resource
def load_data():
    """Load the product metadata table and the FAISS index.

    The metadata is read from a remote parquet file. The FAISS index is
    downloaded once to ``cellphones_index.faiss`` in the working directory
    and read from there on later calls.

    The result is cached with ``st.cache_resource`` so the download and
    parsing do not repeat on every Streamlit rerun. The cached objects are
    shared between reruns, so callers should treat them as read-only.

    Returns:
        tuple: ``(df, index)`` -- the metadata DataFrame and the FAISS index.
    """
    parquet_url = "https://huggingface.co/datasets/McAuley-Lab/Amazon-Reviews-2023/resolve/main/raw_meta_Cell_Phones_and_Accessories/full-00000-of-00007.parquet"
    df = pd.read_parquet(parquet_url)

    index_url = "https://huggingface.co/GovinKin/MGTA415database/resolve/main/cellphones_index.faiss"
    local_index_path = "cellphones_index.faiss"
    if not os.path.exists(local_index_path):
        # Download under a temporary name and rename afterwards, so an
        # interrupted download never leaves a truncated file that later runs
        # would mistake for a complete index.
        partial_path = local_index_path + ".part"
        urllib.request.urlretrieve(index_url, partial_path)
        os.replace(partial_path, local_index_path)

    index = faiss.read_index(local_index_path)
    return df, index
| # --------------------------- | |
| # Search functions | |
| # --------------------------- | |
def search(query, model, df, index, top_k=10):
    """Return the rows of ``df`` nearest to ``query`` in the vector index.

    Args:
        query: Free-text search string.
        model: Object with ``encode(list_of_str)`` returning one embedding
            per input string.
        df: DataFrame whose row positions correspond to the ids stored in
            ``index``.
        index: Object with ``search(vectors, k=...)`` returning
            ``(distances, ids)`` arrays with one row per query vector.
        top_k: Number of neighbours to request from the index.

    Returns:
        A copy of the matching rows with an added ``distance`` column, in
        the order returned by the index. Ids that do not address a row of
        ``df`` are dropped, so fewer than ``top_k`` rows may be returned.
    """
    query_vector = np.asarray(model.encode([query]), dtype="float32")
    distances, indices = index.search(query_vector, k=top_k)

    hit_ids = np.asarray(indices[0])
    hit_distances = np.asarray(distances[0])

    # FAISS pads the id array with -1 when it finds fewer than k neighbours,
    # and the index may reference rows beyond this DataFrame. Passing either
    # to iloc would silently return the wrong row (-1 is the last row) or
    # raise IndexError, so keep only ids that address a real row.
    valid = (hit_ids >= 0) & (hit_ids < len(df))

    results = df.iloc[hit_ids[valid]].copy()
    results["distance"] = hit_distances[valid]
    return results
def search_plus(query, model, df, index, top_k=20):
    """Vector search followed by price and title-keyword filtering.

    Steps:
      1. Run ``search`` to get the ``top_k`` nearest rows.
      2. If the query contains "under N" or "below N" (optional ``$``,
         optional decimals) and the results have a ``price`` column, keep
         only rows whose price is numerically below N. Prices that cannot
         be parsed as numbers are treated as unknown and dropped by this
         filter.
      3. Keep only rows whose title contains at least one query keyword
         (words longer than two characters that are not stop words). The
         price clause itself is not used as a keyword.

    Args:
        query: Free-text search string.
        model, df, index: Passed through to ``search``.
        top_k: Number of neighbours requested before filtering.

    Returns:
        The filtered DataFrame (possibly empty).
    """
    results = search(query, model, df, index, top_k=top_k)
    lowered = query.lower()

    price_match = re.search(r"(under|below)\s*\$?(\d+(?:\.\d+)?)", lowered)
    keyword_text = lowered
    if price_match:
        # Exclude "under $30" from keyword matching; otherwise "under" and
        # "$30" become title keywords and let unrelated products through.
        keyword_text = lowered[:price_match.start()] + " " + lowered[price_match.end():]

        if "price" in results.columns:
            price_under = float(price_match.group(2))
            # Coerce row by row: one unparseable price must not disable the
            # whole filter. Unparseable values become NaN, and NaN compares
            # False against the limit, so those rows are excluded.
            results["price"] = pd.to_numeric(results["price"], errors="coerce")
            results = results[results["price"] < price_under]

    stop_words = {"i", "want", "need", "the", "a", "for", "with", "to", "is", "it", "on", "of", "buy", "and", "in"}
    keywords = [kw for kw in keyword_text.split() if kw not in stop_words and len(kw) > 2]
    if not results.empty and keywords:
        # Escape keywords so characters such as "+" or "(" match literally.
        pattern = "|".join(map(re.escape, keywords))
        results = results[results["title"].str.lower().str.contains(pattern, na=False)]
    return results
def rerank_by_similarity(query, results, model, top_n=5):
    """Re-rank search results by cosine similarity between query and title.

    Args:
        query: Free-text search string.
        results: DataFrame with a ``title`` column. It is not modified.
        model: Object whose ``encode(..., convert_to_tensor=True)`` returns
            embeddings accepted by ``cos_sim``.
        top_n: Maximum number of rows to return.

    Returns:
        ``results`` unchanged if it is empty; otherwise a new DataFrame with
        an added ``similarity`` column, sorted from most to least similar and
        truncated to ``top_n`` rows.
    """
    if results.empty:
        return results

    # Work on a copy so the caller's frame (often a filtered slice of another
    # frame) does not silently gain a "similarity" column.
    ranked = results.copy()

    query_vec = model.encode([query], convert_to_tensor=True)
    titles = ranked["title"].astype(str).tolist()
    title_vecs = model.encode(titles, convert_to_tensor=True)

    # cos_sim returns a (1, n_titles) matrix; row 0 holds the score per title.
    ranked["similarity"] = cos_sim(query_vec, title_vecs)[0].cpu().numpy()
    return ranked.sort_values("similarity", ascending=False).head(top_n)
| # --------------------------- | |
| # Agent Tool: wraps search_plus | |
| # --------------------------- | |
@tool
def product_search_tool(query: str) -> str:
    """Search for cellphone accessories using a natural query."""
    # The @tool decorator turns this function into a LangChain tool so it can
    # be listed in initialize_agent(tools=[...]).
    # NOTE(review): the docstring above presumably serves as the tool
    # description shown to the LLM, so its wording is kept unchanged and
    # implementation notes live in comments -- confirm against LangChain docs.
    #
    # Reads the module-level `model`, `df_all` and `index`, which are assigned
    # later in the script but before the agent can call this tool.
    results = search_plus(query, model, df_all, index, top_k=10)

    # Drop missing titles and force str so "\n".join cannot fail on NaN/None.
    titles = results["title"].dropna().astype(str).head(5).tolist()
    if not titles:
        return "No results found."
    return "\n".join(titles)
| # --------------------------- | |
| # Load all resources | |
| # --------------------------- | |
# Module-level resources. `model`, `df_all` and `index` are read as globals by
# product_search_tool, so they must be assigned before the agent runs a query.
model = load_model()
df_all, index = load_data()
| # --------------------------- | |
| # Agent setup | |
| # --------------------------- | |
# `os` is already imported at the top of the file; no second import is needed.

# Read the OpenAI credentials from Streamlit secrets. If they are missing,
# show a readable message and stop this run instead of a raw traceback.
try:
    openai_secrets = st.secrets["openai"]
    openai_api_key = openai_secrets["api_key"]
except (KeyError, FileNotFoundError):
    st.error(
        "OpenAI credentials are not configured. Add an [openai] section with "
        "an api_key entry to the Streamlit secrets for this app."
    )
    st.stop()

# Export the credentials as environment variables. ChatOpenAI is constructed
# below without explicit credentials, so it presumably picks them up from the
# environment -- verify against the installed LangChain version.
os.environ["OPENAI_API_KEY"] = openai_api_key
os.environ["OPENAI_API_BASE"] = openai_secrets.get("base_url", "https://api.openai.com/v1")

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.3)

# Function-calling agent whose only tool is the product search.
agent = initialize_agent(
    tools=[product_search_tool],
    llm=llm,
    agent=AgentType.OPENAI_FUNCTIONS,
    verbose=True,
)
| # --------------------------- | |
| # Streamlit Chat Interface | |
| # --------------------------- | |
st.title("🤖 AI Product Search Agent")
st.markdown("Ask natural questions like 'cheap rugged iPhone case under $30'")

# The conversation is kept in session state so it survives Streamlit reruns.
# Each entry is a (role, text) tuple with role "user" or "agent".
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

user_input = st.chat_input("Ask about cellphone accessories...")
if user_input:
    st.session_state.chat_history.append(("user", user_input))
    with st.spinner("Agent is thinking..."):
        try:
            reply = agent.run(user_input)
        except Exception as e:
            # Top-level UI boundary: show the failure in the chat instead of
            # crashing the page.
            reply = f"⚠️ Agent error: {e}"
    st.session_state.chat_history.append(("agent", reply))

# Render the full history. Each message gets a unique key, because two
# messages with identical text and role would otherwise collide on the same
# auto-generated component key.
for position, (role, msg) in enumerate(st.session_state.chat_history):
    message(msg, is_user=(role == "user"), key=f"chat_message_{position}")