chore: remove legacy files and scripts no longer part of the main architecture
- .cursorrules +10 -0
- legacy/README.md +0 -10
- legacy/agent/agent_core.py +0 -55
- legacy/agent/data_loader.py +0 -61
- legacy/agent/dialogue_manager.py +0 -39
- legacy/agent/intent_parser.py +0 -64
- legacy/agent/llm_generator.py +0 -92
- legacy/agent/rag_indexer.py +0 -56
- legacy/agent/rag_retriever.py +0 -42
- legacy/app.py +0 -264
- legacy/deploy.sh +0 -47
- legacy/download_fix.py +0 -11
- scripts/__init__.py +1 -0
- scripts/data/__init__.py +1 -0
- scripts/data/build_books_basic_info.py +46 -46
- scripts/data/clean_data.py +39 -30
- scripts/data/generate_emotions.py +39 -63
- scripts/data/generate_tags.py +27 -17
- scripts/data/split_rec_data.py +68 -125
- scripts/model/build_recall_models.py +29 -44
- scripts/model/train_sasrec.py +26 -186
- scripts/model/train_youtube_dnn.py +27 -221
- scripts/run_pipeline.py +190 -138
- src/core/model_loader.py +1 -1
- src/data/__init__.py +5 -0
- src/data/repository.py +94 -0
- src/init_db.py +2 -13
- src/main.py +0 -5
- src/marketing/persona.py +7 -5
- src/marketing/personalized_highlight.py +0 -1
- src/marketing/verify_p3.py +1 -1
- src/recall/embedding.py +161 -4
- src/recall/fusion.py +79 -43
- src/recall/itemcf.py +132 -21
- src/recall/popularity.py +0 -1
- src/recall/sasrec_recall.py +158 -3
- src/recall/sequence_utils.py +48 -0
- src/recommender.py +1 -5
- src/services/chat_service.py +12 -31
- src/vector_db.py +0 -3
.cursorrules
ADDED
@@ -0,0 +1,10 @@
+You are a Senior Python Systems Architect specializing in Machine Learning Engineering.
+Your goal is to refactor a research-prototype code base into a production-grade system.
+
+Guidelines:
+1. **Code Structure**: Follow Clean Architecture principles. Separate concerns strictly between Data Access, Business Logic, and Interface.
+2. **Type Hinting**: All new or refactored functions MUST have Python type hints and docstrings.
+3. **No "Glue Scripts"**: Avoid using `subprocess.run` to call other Python scripts. Import classes and call methods instead.
+4. **Error Handling**: Use specific exception handling, not bare `except Exception`.
+5. **Paths**: Always use `pathlib.Path`, never `os.path.join`.
+6. **Refactoring Safety**: When refactoring, ensure existing logic (feature engineering, recall calculation) is preserved unless explicitly asked to simplify.
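Taken together, these rules read like the following minimal sketch (illustrative only; `load_ratings` is a hypothetical helper, not part of this commit):

from pathlib import Path

import pandas as pd


def load_ratings(path: Path) -> pd.DataFrame:
    """Load the ratings CSV, following the rules above: pathlib, type hints, specific exceptions."""
    if not path.exists():
        # Rule 4: raise a specific exception instead of hiding errors behind a bare `except Exception`.
        raise FileNotFoundError(f"Ratings file not found: {path}")
    return pd.read_csv(path)


# Rule 3: import and call instead of shelling out with subprocess.run.
ratings = load_ratings(Path("data/Books_rating.csv"))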
legacy/README.md
DELETED
@@ -1,10 +0,0 @@
-# Legacy — Not part of main architecture
-
-Code moved here is preserved but not used in the main flow (src.main FastAPI + React).
-
-| File | Note |
-|:---|:---|
-| app.py | Gradio UI (replaced by React + FastAPI) |
-| agent/ | Shopping agent (broken imports, not used) |
-| deploy.sh | Old Hugging Face deployment script |
-| download_fix.py | Temporary fix script |
legacy/agent/agent_core.py
DELETED
@@ -1,55 +0,0 @@
-from intent_parser import IntentParser
-from rag_retriever import ProductRetriever
-from dialogue_manager import DialogueManager
-from llm_generator import LLMGenerator
-import os
-
-class ShoppingAgent:
-    def __init__(self, index_path: str, metadata_path: str, llm_model: str = None):
-        self.parser = IntentParser()
-        self.retriever = ProductRetriever(index_path, metadata_path)
-        self.dialogue_manager = DialogueManager()
-        self.llm = LLMGenerator(model_name=llm_model)  # Defaults to mock
-
-    def process_query(self, query: str):
-        print(f"\nUser: {query}")
-
-        # 1. Parse Intent
-        intent = self.parser.parse(query)
-        # print(f"[Debug] Intent: {intent}")
-
-        # 2. Enrich Query (incorporating history could happen here)
-        search_query = query
-        if intent['category']:
-            search_query += f" {intent['category']}"
-
-        # 3. Retrieve
-        results = self.retriever.search(search_query, k=3)
-
-        # 4. Generate Response using LLM + History
-        history_str = self.dialogue_manager.get_context_string()
-        response = self.llm.generate_response(query, results, history_str)
-
-        # 5. Update Memory
-        self.dialogue_manager.add_turn(query, response)
-
-        print("[Agent]:")
-        print(response)
-        return response
-
-    def reset(self):
-        self.dialogue_manager.clear_history()
-
-if __name__ == "__main__":
-    if not os.path.exists("data/product_index.faiss"):
-        print("Index not found. Please run rag_indexer.py first.")
-    else:
-        # Pass "mock" to force CPU-friendly mock generation,
-        # or pass a model name like "gpt2" (small) if you have 'transformers' installed to test pipeline.
-        agent = ShoppingAgent("data/product_index.faiss", "data/product_metadata.pkl", llm_model="mock")
-
-        print("--- Turn 1 ---")
-        agent.process_query("I need a gaming laptop under $1000")
-
-        print("\n--- Turn 2 ---")
-        agent.process_query("Do you have anything cheaper?")
legacy/agent/data_loader.py
DELETED
@@ -1,61 +0,0 @@
-import pandas as pd
-import random
-
-def generate_synthetic_data(num_samples: int = 100) -> pd.DataFrame:
-    """
-    Generates synthetic e-commerce product data.
-    """
-    categories = ['Electronics', 'Clothing', 'Home & Kitchen', 'Books', 'Toys']
-    adjectives = ['Premium', 'Budget', 'High-end', 'Durable', 'Stylish', 'Compact', 'Professional']
-    products_map = {
-        'Electronics': ['Smartphone', 'Laptop', 'Headphones', 'Smartwatch', 'Camera'],
-        'Clothing': ['T-Shirt', 'Jeans', 'Jacket', 'Sneakers', 'Dress'],
-        'Home & Kitchen': ['Blender', 'Coffee Maker', 'Desk Lamp', 'Sofa', 'Curtains'],
-        'Books': ['Novel', 'Textbook', 'Biography', 'Cookbook', 'Comic'],
-        'Toys': ['Lego Set', 'Action Figure', 'Board Game', 'Puzzle', 'Doll']
-    }
-
-    data = []
-    for i in range(num_samples):
-        cat = random.choice(categories)
-        prod = random.choice(products_map[cat])
-        adj = random.choice(adjectives)
-
-        title = f"{adj} {prod} {i+1}"
-        price = round(random.uniform(10.0, 1000.0), 2)
-        description = f"This is a {adj.lower()} {prod.lower()} perfect for your needs. It features high quality materials and modern design."
-        features = f"Feature A, Feature B, {adj} Quality"
-
-        data.append({
-            'product_id': f"P{str(i).zfill(4)}",
-            'title': title,
-            'category': cat,
-            'price': price,
-            'description': description,
-            'features': features,
-            'review_text': f"Great {prod}! I loved the {adj.lower()} aspect."
-        })
-
-    return pd.DataFrame(data)
-
-def load_data(file_path: str = None) -> pd.DataFrame:
-    """
-    Loads data from a file or generates synthetic data if path is None.
-    """
-    if file_path:
-        # Check extension and load accordingly
-        if file_path.endswith('.csv'):
-            return pd.read_csv(file_path)
-        elif file_path.endswith('.json'):
-            return pd.read_json(file_path)
-        else:
-            raise ValueError("Unsupported file format")
-    else:
-        print("No file path provided. Generating synthetic data...")
-        return generate_synthetic_data()
-
-if __name__ == "__main__":
-    df = load_data()
-    print(df.head())
-    df.to_csv("synthetic_products.csv", index=False)
-    print("Saved synthetic_products.csv")
legacy/agent/dialogue_manager.py
DELETED
@@ -1,39 +0,0 @@
-from typing import List, Dict
-
-class DialogueManager:
-    def __init__(self, max_history: int = 5):
-        self.history: List[Dict[str, str]] = []
-        self.max_history = max_history
-
-    def add_turn(self, user_input: str, system_response: str):
-        """
-        Adds a single turn to the history.
-        """
-        self.history.append({"role": "user", "content": user_input})
-        self.history.append({"role": "assistant", "content": system_response})
-
-        # Keep history within limit (rolling buffer)
-        if len(self.history) > self.max_history * 2:
-            self.history = self.history[-(self.max_history * 2):]
-
-    def get_history(self) -> List[Dict[str, str]]:
-        """
-        Returns the conversation history.
-        """
-        return self.history
-
-    def clear_history(self):
-        """
-        Resets the conversation.
-        """
-        self.history = []
-
-    def get_context_string(self) -> str:
-        """
-        Returns history formatted as a string for simple prompts.
-        """
-        context = ""
-        for turn in self.history:
-            role = "User" if turn["role"] == "user" else "Agent"
-            context += f"{role}: {turn['content']}\n"
-        return context
legacy/agent/intent_parser.py
DELETED
@@ -1,64 +0,0 @@
-import re
-from typing import Dict, Optional
-
-class IntentParser:
-    def __init__(self):
-        # In a real scenario, this would be an LLM-based parser
-        pass
-
-    def parse(self, query: str) -> Dict[str, Optional[str]]:
-        """
-        Parses the user query into structured slots.
-        """
-        query = query.lower()
-
-        intent = {
-            'category': None,
-            'budget': None,
-            'style': None,
-            'original_query': query
-        }
-
-        # Rule-based Category Extraction
-        categories = ['laptop', 'phone', 'smartphone', 'headphone', 'camera', 'jeans', 'shirt', 'dress', 'shoe', 'blender', 'coffee', 'lamp', 'sofa', 'desk', 'toy', 'lego', 'book', 'novel']
-        for cat in categories:
-            if cat in query:
-                intent['category'] = cat
-                break  # Take the first match for now
-
-        # Rule-based Budget Extraction
-        # Look for "under $100", "cheap", "expensive", "budget"
-        if "cheap" in query or "budget" in query:
-            intent['budget'] = "low"
-        elif "expensive" in query or "premium" in query:
-            intent['budget'] = "high"
-
-        match = re.search(r'under \$?(\d+)', query)
-        if match:
-            intent['budget'] = f"<{match.group(1)}"
-
-        # Rule-based Style/Feature Extraction (naïve)
-        # Everything else that is an adjective could be style
-        styles = ['gaming', 'professional', 'casual', 'formal', 'black', 'red', 'blue', 'wireless', 'bluetooth']
-        found_styles = []
-        for style in styles:
-            if style in query:
-                found_styles.append(style)
-
-        if found_styles:
-            intent['style'] = ", ".join(found_styles)
-
-        return intent
-
-if __name__ == "__main__":
-    parser = IntentParser()
-    queries = [
-        "I want a cheap gaming laptop",
-        "Looking for a blue dress under $50",
-        "wireless headphones for travel"
-    ]
-
-    for q in queries:
-        print(f"Query: {q}")
-        print(f"Parsed: {parser.parse(q)}")
-        print("-" * 20)
legacy/agent/llm_generator.py
DELETED
@@ -1,92 +0,0 @@
-from typing import List, Dict, Optional
-import os
-
-class LLMGenerator:
-    def __init__(self, model_name: str = None, device: str = "cpu"):
-        """
-        Initialize LLM.
-        Args:
-            model_name: HuggingFace model name (e.g., 'meta-llama/Meta-Llama-3-8B-Instruct').
-                If None, uses a Mock generator.
-            device: 'cpu' or 'cuda'.
-        """
-        self.model_name = model_name
-        self.device = device
-        self.pipeline = None
-
-        if self.model_name and self.model_name != "mock":
-            try:
-                from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
-                import torch
-
-                print(f"Loading LLM: {model_name} on {device}...")
-                # Note: In a real script, we would handle quantization (bitsandbytes) here
-                # based on the device capabilities we discussed.
-                dtype = torch.float16 if device == 'cuda' else torch.float32
-
-                self.pipeline = pipeline(
-                    "text-generation",
-                    model=model_name,
-                    torch_dtype=dtype,
-                    device_map="auto" if device == 'cuda' else "cpu"
-                )
-            except Exception as e:
-                print(f"Failed to load model {model_name}: {e}")
-                print("Falling back to Mock Generator.")
-                self.model_name = "mock"
-
-    def generate_response(self, user_query: str, retrieved_items: List[Dict], history_str: str) -> str:
-        """
-        Generates a natural language response based on context.
-        """
-        # 1. Format retrieved items
-        items_str = ""
-        for i, item in enumerate(retrieved_items):
-            items_str += f"{i+1}. {item['title']} (${item['price']}): {item['description']}\n"
-
-        # 2. Construct Prompt (Simple Template)
-        prompt = f"""You are a helpful shopping assistant.
-
-Context History:
-{history_str}
-
-Retrieved Products related to the user's request:
-{items_str}
-
-User's Query: {user_query}
-
-Instructions:
-- Recommend the best products from the list above.
-- Explain WHY they fit the user's request (budget, style, category).
-- Be concise and friendly.
-
-Response:"""
-
-        if self.model_name == "mock" or self.model_name is None:
-            return self._mock_generation(items_str)
-        else:
-            # Real LLM Generation
-            try:
-                outputs = self.pipeline(
-                    prompt,
-                    max_new_tokens=200,
-                    do_sample=True,
-                    temperature=0.7,
-                    truncation=True
-                )
-                generated_text = outputs[0]['generated_text']
-                # Extract only the response part if the model echos the prompt (common in base pipelines)
-                if "Response:" in generated_text:
-                    return generated_text.split("Response:")[-1].strip()
-                return generated_text
-            except Exception as e:
-                return f"[Error generating response: {e}]"
-
-    def _mock_generation(self, items_str):
-        """
-        Fallback logic for testing without a GPU.
-        """
-        if not items_str:
-            return "I couldn't find any products matching your specific criteria. Could you try different keywords?"
-
-        return f"Based on your request, I found these great options:\n{items_str}\nI recommend checking the first one as it offers the best value!"
legacy/agent/rag_indexer.py
DELETED
@@ -1,56 +0,0 @@
-import os
-import faiss
-import numpy as np
-import pickle
-from sentence_transformers import SentenceTransformer
-import pandas as pd
-try:
-    from src.data_loader import load_data
-except ImportError:
-    from data_loader import load_data  # Fallback for direct execution
-
-class RAGIndexer:
-    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
-        self.model = SentenceTransformer(model_name)
-        self.index = None
-        self.metadata = []
-
-    def build_index(self, data: pd.DataFrame):
-        """
-        Builds the Faiss index from the product dataframe.
-        """
-        print("Encoding product data...")
-        # Create a rich text representation for embedding
-        # Title + Description + Features + Category + Price (as text)
-        documents = data.apply(lambda x: f"{x['title']} {x['description']} Category: {x['category']} Price: {x['price']}", axis=1).tolist()
-
-        embeddings = self.model.encode(documents, show_progress_bar=True)
-        dimension = embeddings.shape[1]
-
-        self.index = faiss.IndexFlatL2(dimension)
-        self.index.add(embeddings.astype('float32'))
-
-        self.metadata = data.to_dict('records')
-        print(f"Index built with {len(self.metadata)} items.")
-
-    def save(self, index_path: str, metadata_path: str):
-        """
-        Saves the index and metadata to disk.
-        """
-        if self.index:
-            faiss.write_index(self.index, index_path)
-            with open(metadata_path, 'wb') as f:
-                pickle.dump(self.metadata, f)
-            print(f"Saved index to {index_path} and metadata to {metadata_path}")
-        else:
-            print("No index to save.")
-
-if __name__ == "__main__":
-    # Scaffolding run
-    df = load_data()  # Generates synthetic
-    indexer = RAGIndexer()
-    indexer.build_index(df)
-
-    # Ensure output dir exists
-    os.makedirs("data", exist_ok=True)
-    indexer.save("data/product_index.faiss", "data/product_metadata.pkl")
legacy/agent/rag_retriever.py
DELETED
@@ -1,42 +0,0 @@
-import faiss
-import pickle
-import numpy as np
-from sentence_transformers import SentenceTransformer
-from typing import List, Dict
-
-class ProductRetriever:
-    def __init__(self, index_path: str, metadata_path: str, model_name: str = 'all-MiniLM-L6-v2'):
-        self.model = SentenceTransformer(model_name)
-
-        print(f"Loading index from {index_path}...")
-        self.index = faiss.read_index(index_path)
-
-        print(f"Loading metadata from {metadata_path}...")
-        with open(metadata_path, 'rb') as f:
-            self.metadata = pickle.load(f)
-
-    def search(self, query: str, k: int = 5) -> List[Dict]:
-        """
-        Searches for the top-k most relevant products.
-        """
-        query_vector = self.model.encode([query]).astype('float32')
-        distances, indices = self.index.search(query_vector, k)
-
-        results = []
-        for i, idx in enumerate(indices[0]):
-            if idx < len(self.metadata):
-                item = self.metadata[idx]
-                item['score'] = float(distances[0][i])
-                results.append(item)
-
-        return results
-
-if __name__ == "__main__":
-    # Test run
-    retriever = ProductRetriever("data/product_index.faiss", "data/product_metadata.pkl")
-    query = "cheap gaming laptop"
-    results = retriever.search(query)
-
-    print(f"Query: {query}")
-    for res in results:
-        print(f"  - {res['title']} (${res['price']}) [Score: {res['score']:.4f}]")
legacy/app.py
DELETED
@@ -1,264 +0,0 @@
-import gradio as gr
-import logging
-import os
-import requests
-import json
-from typing import List, Tuple, Any
-from src.utils import setup_logger
-
-# --- Configuration ---
-API_URL = os.getenv("API_URL", "http://localhost:6006")  # Localhost via SSH Tunnel
-
-# --- Initialize Logger ---
-logger = setup_logger(__name__)
-
-# --- Module Initialization ---
-# (We no longer load model locally; we query the remote API)
-categories = ["All", "Fiction", "History", "Science", "Technology"]  # Fallback/Mock for now
-tones = ["All", "Happy", "Surprising", "Angry", "Suspenseful", "Sad"]
-
-def fetch_tones():
-    try:
-        resp = requests.get(f"{API_URL}/tones", timeout=3)
-        if resp.status_code == 200:
-            data = resp.json()
-            tns = data.get("tones") if isinstance(data, dict) else None
-            if isinstance(tns, list) and len(tns) > 0:
-                return tns
-    except Exception as e:
-        logger.warning(f"fetch_tones failed: {e}")
-    return tones
-
-def fetch_categories():
-    try:
-        resp = requests.get(f"{API_URL}/categories", timeout=3)
-        if resp.status_code == 200:
-            data = resp.json()
-            cats = data.get("categories") if isinstance(data, dict) else None
-            if isinstance(cats, list) and len(cats) > 0:
-                return cats
-    except Exception as e:
-        logger.warning(f"fetch_categories failed: {e}")
-    return categories
-
-# Try to fetch real categories on startup
-categories = fetch_categories()
-tones = fetch_tones()
-
-# Initialize Shopping Agent (Mock or Real)
-# Note: Real agent requires FAISS index. We'll handle checks later.
-try:
-    # from legacy.agent.agent_core import ShoppingAgent
-    # shopping_agent = ShoppingAgent(...)
-    pass
-except ImportError:
-    logger.warning("Shopping Agent module not found or failed to import.")
-
-# --- Business Logic: Tab 1 (Discovery) ---
-def recommend_books(query: str, category: str, tone: str):
-    """Fetch recommendations and return both gallery items and raw data."""
-    try:
-        if not query.strip():
-            return [], []
-
-        payload = {
-            "query": query,
-            "category": category if category else "All",
-            "tone": tone if tone else "All"
-        }
-
-        logger.info(f"Sending request to {API_URL}/recommend")
-        response = requests.post(f"{API_URL}/recommend", json=payload, timeout=25)
-
-        if response.status_code == 200:
-            data = response.json()
-            results = data.get("recommendations", [])
-            gallery_items = [(item["thumbnail"], f"{item['title']}\n{item['authors']}") for item in results]
-            return gallery_items, results
-        else:
-            logger.error(f"API Error: {response.text}")
-            return [], []
-
-    except Exception as e:
-        logger.error(f"Error in recommend_books: {e}")
-        return [], []
-
-
-def show_book_details(evt: Any, recs: List[dict]):
-    """Populate detail panel when a gallery item is selected and prep a QA hint."""
-    try:
-        if recs is None:
-            return "", "", "", "", "", -1
-        idx = evt.index if evt and hasattr(evt, "index") else None
-        if idx is None or idx >= len(recs):
-            return "", "", "", "", "", -1
-        book = recs[idx]
-        title_block = f"### {book['title']}\n**Authors:** {book['authors']}\n**ISBN:** {book['isbn']}"
-        desc_block = f"**Description**\n\n{book['description']}"
-        rank_block = f"**Rank:** #{idx + 1}"  # simple positional rank
-        comments_block = "**Reviews (sample):**\n- Exceptional pacing and character depth.\n- A must-read for this genre."
-        qa_hint = f"Ask the assistant: Tell me more about '{book['title']}' by {book['authors']}."
-        return title_block, rank_block, comments_block, desc_block, qa_hint, idx
-    except Exception as e:
-        logger.error(f"Error showing book details: {e}")
-        return "", "", "", "", "", -1
-
-def clear_discovery():
-    return "", "All", "All", []
-
-
-def add_to_favorites(selected_idx: int, recs: List[dict]):
-    try:
-        if selected_idx is None or selected_idx < 0 or not recs or selected_idx >= len(recs):
-            return "Please select a book from the gallery first."
-        book = recs[selected_idx]
-        payload = {"user_id": "local", "isbn": book["isbn"]}
-        resp = requests.post(f"{API_URL}/favorites/add", json=payload, timeout=8)
-        if resp.status_code == 200:
-            data = resp.json()
-            return f"✅ Added to favorites: {book['title']} ({data.get('favorites_count', '?')} books in collection)"
-        return f"❌ Failed to add: {resp.text}"
-    except Exception as e:
-        logger.error(f"add_to_favorites error: {e}")
-        return "❌ Error adding to favorites. Try again later."
-
-
-def generate_highlights(selected_idx: int, recs: List[dict]):
-    try:
-        if selected_idx is None or selected_idx < 0 or not recs or selected_idx >= len(recs):
-            return "(Hint) Please select a book from the gallery, then click Generate Highlights."
-        book = recs[selected_idx]
-        payload = {"isbn": book["isbn"], "user_id": "local"}
-        resp = requests.post(f"{API_URL}/marketing/highlights", json=payload, timeout=12)
-        if resp.status_code != 200:
-            return "Failed to generate highlights. Try again later."
-        data = resp.json()
-        persona = data.get("persona", {})
-        highlights = data.get("highlights", [])
-        header = f"### Personalized Highlights ({book['title']})\n"
-        persona_md = f"> Your Profile: {persona.get('summary','N/A')}\n\n" if persona else ""
-        bullets = "\n".join([f"- {h}" for h in highlights]) if highlights else "- No highlights available"
-        return header + persona_md + bullets
-    except Exception as e:
-        logger.error(f"generate_highlights error: {e}")
-        return "Error generating highlights. Try again later."
-
-# --- Business Logic: Tab 2 (Assistant) ---
-def chat_response(message, history):
-    """Answer book questions using the recommender API as a knowledge source."""
-    try:
-        if not message.strip():
-            return "Please describe the book or question you have."
-
-        # Use the same recommend endpoint as retrieval to ground answers
-        payload = {"query": message, "category": "All", "tone": "All"}
-        resp = requests.post(f"{API_URL}/recommend", json=payload, timeout=20)
-        if resp.status_code != 200:
-            return "Unable to retrieve book information. Try again later."
-
-        data = resp.json()
-        recs = data.get("recommendations", [])
-        if not recs:
-            return "No matching books found. Try a different query."
-
-        top = recs[0]
-        answer = [
-            f"**{top.get('title','')}**",
-            f"Author: {top.get('authors','Unknown')}",
-            f"Summary: {top.get('description','No summary available')}"
-        ]
-        # If more results, suggest to check discovery tab
-        if len(recs) > 1:
-            answer.append("More results available in the Find Books tab.")
-        return "\n\n".join(answer)
-    except Exception as e:
-        logger.error(f"chat_response error: {e}")
-        return "Error processing your question. Try again later."
-
-# --- Business Logic: Tab 3 (Marketing) ---
-def generate_marketing_copy(product_name, features, target_audience):
-    # Placeholder for Marketing Content Engine
-    # from src.marketing.guardrails import SafetyCheck...
-    return f"""
-📣 **CALLING ALL {target_audience.upper()}!**
-
-Presenting **{product_name}** — the treasure you've been seeking.
-
-✨ **Why you'll love it:**
-{features}
-
-Perfect for your collection. Add it to your shelf today.
-"""
-
-# --- UI Construction ---
-with gr.Blocks(title="Paper Shelf - Book Discovery", theme=gr.themes.Soft()) as dashboard:
-
-    gr.Markdown("# 📚 Paper Shelf")
-    gr.Markdown("Intelligent book discovery powered by semantic search: **Find Books**, **Ask Questions**, **Generate Marketing Copy**.")
-
-    with gr.Tabs():
-
-        # --- Tab 1: Discovery ---
-        with gr.TabItem("🔍 Find Books (Search & Recommendations)"):
-            rec_state = gr.State([])  # store full recommendation data
-            qa_hint = gr.State("")
-            sel_idx = gr.State(-1)
-            with gr.Row():
-                with gr.Column(scale=3):
-                    q_input = gr.Textbox(label="What are you looking for?", placeholder="e.g., a mystery novel with fast pacing")
-                with gr.Column(scale=1):
-                    cat_input = gr.Dropdown(label="Category", choices=categories, value="All")
-                    tone_input = gr.Dropdown(label="Mood/Tone", choices=tones, value="All")
-
-            btn_rec = gr.Button("Find Books", variant="primary")
-            gallery = gr.Gallery(label="Results", columns=4, height="auto")
-            with gr.Row():
-                with gr.Column(scale=2):
-                    title_info = gr.Markdown(label="Book Info")
-                    desc_info = gr.Markdown(label="Description")
-                with gr.Column(scale=1):
-                    rank_info = gr.Markdown(label="Ranking")
-                    comments_info = gr.Markdown(label="Reviews")
-                    qa_hint_md = gr.Markdown(label="Ask the Assistant", value="(Click a book to see suggested questions)")
-
-            with gr.Row():
-                btn_fav = gr.Button("⭐ Add to Favorites", variant="secondary")
-                btn_high = gr.Button("✨ Generate Highlights", variant="primary")
-            fav_status = gr.Markdown(label="Status")
-            highlights_md = gr.Markdown(label="Personalized Highlights")
-
-            btn_rec.click(recommend_books, [q_input, cat_input, tone_input], [gallery, rec_state])
-            gallery.select(show_book_details, [rec_state], [title_info, rank_info, comments_info, desc_info, qa_hint_md, sel_idx])
-            btn_fav.click(add_to_favorites, [sel_idx, rec_state], [fav_status])
-            btn_high.click(generate_highlights, [sel_idx, rec_state], [highlights_md])
-
-        # --- Tab 2: AI Assistant ---
-        with gr.TabItem("💬 Ask Questions (RAG Assistant)"):
-            chatbot = gr.ChatInterface(
-                fn=chat_response,
-                examples=["Is there a mystery with time travel?", "Recommend sci-fi with female protagonists"],
-                title="Intelligent Book Assistant",
-                description="Search and learn about books through conversational AI."
-            )
-
-        # --- Tab 3: Marketing ---
-        with gr.TabItem("✍️ Create Marketing Copy (GenAI)"):
-            with gr.Row():
-                m_name = gr.Textbox(label="Book Title/Hook", value="The Hobbit - First Edition, Near Mint")
-                m_feat = gr.Textbox(label="Key Features/Condition", value="Near mint condition, no markings, ships worldwide")
-                m_aud = gr.Textbox(label="Target Audience", value="Fantasy enthusiasts, collectors")
-
-            btn_gen = gr.Button("Generate Listing", variant="primary")
-            m_out = gr.Markdown(label="Generated Copy")
-
-            btn_gen.click(generate_marketing_copy, [m_name, m_feat, m_aud], m_out)
-
-if __name__ == "__main__":
-    import os
-    assets_path = os.path.join(os.path.dirname(__file__), "assets")
-    dashboard.launch(
-        server_name="0.0.0.0",
-        server_port=7860,
-        allowed_paths=[assets_path],
-        share=True
-    )
legacy/deploy.sh
DELETED
@@ -1,47 +0,0 @@
-#!/bin/bash
-
-echo "🚀 Preparing deployment to Hugging Face Spaces..."
-
-# Check required files
-echo "📋 Checking required files..."
-required_files=("gradio-dashboard.py" "books_with_emotions.csv" "books_descriptions.txt" "cover-not-found.jpg" "requirements.txt")
-
-for file in "${required_files[@]}"; do
-    if [ -f "$file" ]; then
-        echo "✅ $file exists"
-    else
-        echo "❌ $file is missing"
-        exit 1
-    fi
-done
-
-# Rename the main file to app.py (Hugging Face convention)
-if [ -f "gradio-dashboard.py" ]; then
-    cp gradio-dashboard.py app.py
-    echo "✅ Created app.py"
-fi
-
-# Check Git status
-echo "📝 Checking Git status..."
-if [ -d ".git" ]; then
-    echo "✅ Git repository initialized"
-    git status
-else
-    echo "⚠️ No Git repository detected; run first:"
-    echo "   git init"
-    echo "   git add ."
-    echo "   git commit -m 'Prepare deployment'"
-    echo "   git remote add origin https://github.com/<your-username>/book-recommender.git"
-    echo "   git push -u origin main"
-fi
-
-echo ""
-echo "🎯 Next steps:"
-echo "1. Visit https://huggingface.co/spaces"
-echo "2. Click 'Create new Space'"
-echo "3. Choose the 'Gradio' SDK"
-echo "4. Connect your GitHub repository"
-echo "5. Add HUGGINGFACEHUB_API_TOKEN in Settings"
-echo "6. Wait for the automatic deployment to finish"
-echo ""
-echo "📖 See HUGGINGFACE_DEPLOYMENT.md for details"
legacy/download_fix.py
DELETED
@@ -1,11 +0,0 @@
-import os
-os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
-from huggingface_hub import snapshot_download
-
-print("🚀 Downloading model from hf-mirror...")
-snapshot_download(
-    repo_id="sentence-transformers/all-MiniLM-L6-v2",
-    ignore_patterns=["*.bin", "*.h5", "*.ot"],  # Only download safetensors to save bandwidth
-    resume_download=True
-)
-print("✅ Download Complete!")
scripts/__init__.py
ADDED
@@ -0,0 +1 @@
+# Pipeline scripts package
scripts/data/__init__.py
ADDED
@@ -0,0 +1 @@
+# Data processing scripts
scripts/data/build_books_basic_info.py
CHANGED
@@ -1,48 +1,48 @@
-import pandas as pd
 import csv
-# Save the new table, force-quoting all fields so description and similar fields are not truncated
-merged.to_csv("data/books_basic_info.csv", index=False, quoting=csv.QUOTE_ALL, quotechar='"', escapechar='\\')
-print("Generated data/books_basic_info.csv with the basic book info fields.")
+import logging
+from pathlib import Path
+
+import pandas as pd
 
+logger = logging.getLogger(__name__)
+
+
+def run(
+    books_path: Path = Path("data/books_data.csv"),
+    ratings_path: Path = Path("data/Books_rating.csv"),
+    output_path: Path = Path("data/books_basic_info.csv"),
+) -> None:
+    """Build books basic info from raw data. Callable from Pipeline."""
+    books_data = pd.read_csv(
+        str(books_path),
+        engine="python",
+        quotechar='"',
+        escapechar='\\',
+        on_bad_lines='skip',
+    )
+    ratings = pd.read_csv(
+        str(ratings_path),
+        engine="python",
+        quotechar='"',
+        escapechar='\\',
+        on_bad_lines='skip',
+    )
+
+    books_cols = ["Title", "description", "authors", "image", "publisher", "publishedDate", "categories"]
+    books_data = books_data[books_cols]
+    ratings = ratings[["Title", "Id", "review/score"]].drop_duplicates(subset=["Title"])
+    merged = books_data.merge(ratings, on="Title", how="left")
+    merged = merged.rename(columns={
+        "Id": "isbn10", "Title": "title", "authors": "authors", "description": "description",
+        "image": "image", "publisher": "publisher", "publishedDate": "publishedDate",
+        "categories": "categories", "review/score": "average_rating"
+    })
+    merged["isbn13"] = None
+
+    merged.to_csv(str(output_path), index=False, quoting=csv.QUOTE_ALL, quotechar='"', escapechar='\\')
+    logger.info("Saved %s", output_path)
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO)
+    run()
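With the script exposed as run(), a pipeline can import the step and call it directly, in line with the "No Glue Scripts" rule. A minimal sketch of such a call site (the actual scripts/run_pipeline.py wiring is not shown in this diff, so the snippet is illustrative):

from pathlib import Path

from scripts.data import build_books_basic_info

# Call the step as a plain function; no subprocess.run needed.
build_books_basic_info.run(
    books_path=Path("data/books_data.csv"),
    ratings_path=Path("data/Books_rating.csv"),
    output_path=Path("data/books_basic_info.csv"),
)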
scripts/data/clean_data.py
CHANGED
@@ -226,23 +226,21 @@ def analyze_data_quality(df: pd.DataFrame, text_columns: list) -> dict:
     return stats
 
 
-def main():
-        raise FileNotFoundError(f"Input file not found: {args.input}")
-    df = pd.read_csv(args.input)
+def run(
+    backup: bool = False,
+    input_path: Optional[Path] = None,
+    output_path: Optional[Path] = None,
+    dry_run: bool = False,
+) -> None:
+    """Clean text data. Callable from Pipeline."""
+    input_path = input_path or Path("data/books_processed.csv")
+    output_path = output_path or input_path
+
+    if not input_path.exists():
+        raise FileNotFoundError(f"Input file not found: {input_path}")
 
+    logger.info(f"Loading data from {input_path}")
+    df = pd.read_csv(input_path)
     logger.info(f"Loaded {len(df):,} records")
 
     # Define columns to clean
@@ -261,31 +259,42 @@ def main():
     for col, s in stats_before.items():
         logger.info(f"  {col}: {s['has_html']} HTML, {s['has_url']} URLs, avg_len={s['avg_length']:.0f}")
 
-    if args.dry_run:
+    if dry_run:
         logger.info("\n[DRY RUN] No changes will be saved")
         return
-
-    # Clean
+
     logger.info("\n🧹 Cleaning data...")
     df = clean_dataframe(df, text_columns, max_lengths)
-
-    # Analyze after
+
     logger.info("\n📊 Data quality AFTER cleaning:")
     stats_after = analyze_data_quality(df, text_columns)
     for col, s in stats_after.items():
         logger.info(f"  {col}: {s['has_html']} HTML, {s['has_url']} URLs, avg_len={s['avg_length']:.0f}")
-
-    backup_path = args.output.with_suffix('.csv.bak')
+
+    if backup and output_path.exists():
+        backup_path = output_path.with_suffix('.csv.bak')
         logger.info(f"Creating backup: {backup_path}")
-
-    df.to_csv(args.output, index=False)
+        output_path.rename(backup_path)
+
+    logger.info(f"\n💾 Saving to {output_path}")
+    df.to_csv(output_path, index=False)
     logger.info("✅ Done!")
 
 
+def main():
+    parser = argparse.ArgumentParser(description="Clean text data in books dataset")
+    parser.add_argument("--input", type=Path, default=Path("data/books_processed.csv"))
+    parser.add_argument("--output", type=Path, default=None)
+    parser.add_argument("--dry-run", action="store_true", help="Analyze without saving")
+    parser.add_argument("--backup", action="store_true", help="Create backup before overwriting")
+    args = parser.parse_args()
+    run(
+        backup=args.backup,
+        input_path=args.input,
+        output_path=args.output or args.input,
+        dry_run=args.dry_run,
+    )
+
+
 if __name__ == "__main__":
     main()
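The refactor keeps run() importable while main() stays a thin argparse wrapper. Illustrative usage, assuming the package is importable as scripts.data.clean_data from the repo root (the exact invocation is not part of this diff):

# CLI: analyze only, then clean in place with a .csv.bak backup
#   python -m scripts.data.clean_data --input data/books_processed.csv --dry-run
#   python -m scripts.data.clean_data --input data/books_processed.csv --backup

# Programmatic use from the pipeline
from pathlib import Path

from scripts.data import clean_data

clean_data.run(input_path=Path("data/books_processed.csv"), backup=True)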
scripts/data/generate_emotions.py
CHANGED
@@ -76,72 +76,32 @@ def scores_to_vector(scores: List[Dict[str, float]]) -> Dict[str, float]:
     return mapped
 
 
-def main():
-        raise FileNotFoundError(f"Input file not found: {args.input}")
-
-    logger.info("Loading data from %s", args.input)
-    df = pd.read_csv(args.input)
+def run(
+    input_path: Path = Path("data/books_processed.csv"),
+    output_path: Path = Path("data/books_processed.csv"),
+    batch_size: int = 16,
+    device=None,
+) -> None:
+    """Generate emotion scores. Callable from Pipeline."""
+    if not input_path.exists():
+        raise FileNotFoundError(f"Input file not found: {input_path}")
+
+    logger.info("Loading data from %s", input_path)
+    df = pd.read_csv(input_path)
     if "description" not in df.columns:
         raise ValueError("Input CSV must have a 'description' column")
 
-    if args.max_rows:
-        df = df.head(args.max_rows)
-        logger.info("Truncated to %d rows for max_rows", len(df))
-
-    n = len(df)
-    # Normalize device arg
-    dev: str | int | None
-    if args.device is None:
-        dev = None
-    else:
-        if isinstance(args.device, str) and args.device.lower() == "mps":
-            dev = "mps"
-        else:
-            try:
-                dev = int(args.device)
-            except ValueError:
-                dev = None
-    model = load_model(dev)
-
-    # Prepare containers
     for col in TARGET_LABELS:
         if col not in df.columns:
             df[col] = 0.0
 
-    if args.resume and args.output.exists():
-        logger.info("Resume enabled: loading existing output from %s", args.output)
-        df_prev = pd.read_csv(args.output)
-        for col in TARGET_LABELS:
-            if col in df_prev.columns:
-                df[col] = df_prev[col]
-
+    model = load_model(device)
     texts = df["description"].fillna("").astype(str).tolist()
-
-    checkpoint = max(1, args.checkpoint)
-
-    logger.info("Scoring %d descriptions (batch=%d, checkpoint=%d)...", n, batch, checkpoint)
-    total_batches = (n + batch - 1) // batch
-    for bidx, start in enumerate(tqdm(range(0, n, batch), total=total_batches)):
-        end = min(start + batch, n)
-
-        # Skip already-computed rows when resuming (all scores > 0)
-        if args.resume:
-            existing = df.loc[start:end-1, TARGET_LABELS].values
-            if np.all(existing > 0):
-                continue
+    n = len(df)
+
+    logger.info("Scoring %d descriptions...", n)
+    for start in tqdm(range(0, n, batch_size)):
+        end = min(start + batch_size, n)
         chunk = texts[start:end]
         outputs = model(chunk, truncation=True, max_length=512, top_k=None)
         for i, out in enumerate(outputs):
@@ -150,13 +110,29 @@ def main():
         for col in TARGET_LABELS:
             df.at[idx, col] = vec[col]
 
+    logger.info("Writing to %s", output_path)
+    df.to_csv(output_path, index=False)
+
+
+def main():
+    ap = argparse.ArgumentParser(description="Generate emotion scores from descriptions")
+    ap.add_argument("--input", type=Path, default=Path("data/books_processed.csv"))
+    ap.add_argument("--output", type=Path, default=Path("data/books_processed.csv"))
+    ap.add_argument("--batch-size", type=int, default=16)
+    ap.add_argument("--max-rows", type=int, default=None, help="Optional cap for debugging")
+    ap.add_argument("--device", default=None, help="'mps' for Apple GPU, CUDA device id, or omit for CPU")
+    ap.add_argument("--checkpoint", type=int, default=5000, help="Rows between checkpoint writes")
+    ap.add_argument("--resume", action="store_true", help="Resume if output exists (skip rows with scores)")
+    args = ap.parse_args()
+    dev = None
+    if args.device:
+        dev = "mps" if str(args.device).lower() == "mps" else (int(args.device) if str(args.device).isdigit() else None)
+    run(
+        input_path=args.input,
+        output_path=args.output,
+        batch_size=args.batch_size,
+        device=dev,
+    )
 
 
 if __name__ == "__main__":
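Because run() now accepts the device directly, callers can pick hardware without going through CLI parsing. A short usage sketch (the call site below is an assumption, not shown in this diff):

from pathlib import Path

from scripts.data import generate_emotions

# CPU by default; pass "mps" on Apple Silicon or a CUDA device index such as 0.
generate_emotions.run(
    input_path=Path("data/books_processed.csv"),
    output_path=Path("data/books_processed.csv"),
    batch_size=16,
    device="mps",
)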
scripts/data/generate_tags.py
CHANGED
@@ -116,6 +116,30 @@ def compute_tags(corpus: List[str], top_n: int, max_features: int, min_df: int,
     return tags
 
 
+def run(
+    input_path: Path = Path("data/books_processed.csv"),
+    output_path: Path = Path("data/books_processed.csv"),
+    top_n: int = 8,
+    max_features: int = 60000,
+    min_df: int = 5,
+    max_df: float = 0.5,
+) -> None:
+    """Generate per-book tags. Callable from Pipeline."""
+    if not input_path.exists():
+        raise FileNotFoundError(f"Input file not found: {input_path}")
+
+    logger.info("Loading data from %s", input_path)
+    df = pd.read_csv(input_path)
+    if "description" not in df.columns:
+        raise ValueError("Input CSV must have a 'description' column")
+
+    corpus = [normalize_text(x) for x in df["description"].fillna("").astype(str).tolist()]
+    tags = compute_tags(corpus, top_n=top_n, max_features=max_features, min_df=min_df, max_df=max_df)
+    df["tags"] = tags
+    logger.info("Writing tagged data to %s", output_path)
+    df.to_csv(output_path, index=False)
+
+
 def main():
     parser = argparse.ArgumentParser(description="Generate per-book tags from descriptions")
     parser.add_argument("--input", type=Path, default=Path("data/books_processed.csv"))
@@ -125,29 +149,15 @@ def main():
     parser.add_argument("--min-df", type=int, default=5)
     parser.add_argument("--max-df", type=float, default=0.5)
     args = parser.parse_args()
-
-    logger.info("Loading data from %s", args.input)
-    df = pd.read_csv(args.input)
-    if "description" not in df.columns:
-        raise ValueError("Input CSV must have a 'description' column")
-
-    corpus = [normalize_text(x) for x in df["description"].fillna("").astype(str).tolist()]
-    tags = compute_tags(
-        corpus,
+    run(
+        input_path=args.input,
+        output_path=args.output,
         top_n=args.top_n,
         max_features=args.max_features,
        min_df=args.min_df,
         max_df=args.max_df,
     )
-
-    df["tags"] = tags
-    logger.info("Writing tagged data to %s", args.output)
-    df.to_csv(args.output, index=False)
-    logger.info("Done. Sample tags: %s", tags[0:3])
-
 
 if __name__ == "__main__":
     main()
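compute_tags itself is not touched by this commit and its body is not shown above; the sketch below only illustrates the kind of TF-IDF top-N extraction its signature suggests (an assumption based on the parameter names, using scikit-learn):

from typing import List

import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer


def compute_tags_sketch(corpus: List[str], top_n: int = 8, max_features: int = 60000,
                        min_df: int = 5, max_df: float = 0.5) -> List[str]:
    """Return a comma-joined string of the top-N TF-IDF terms per document."""
    vectorizer = TfidfVectorizer(max_features=max_features, min_df=min_df,
                                 max_df=max_df, stop_words="english")
    matrix = vectorizer.fit_transform(corpus)        # sparse matrix, shape (n_docs, n_terms)
    vocab = np.array(vectorizer.get_feature_names_out())
    tags = []
    for row in matrix:
        weights = row.toarray().ravel()
        top_idx = weights.argsort()[::-1][:top_n]    # indices of the highest-weight terms
        tags.append(", ".join(vocab[i] for i in top_idx if weights[i] > 0))
    return tags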
scripts/data/split_rec_data.py
CHANGED
|
@@ -4,7 +4,7 @@
|
|
| 4 |
|
| 5 |
划分策略: 时序划分 (Leave-Last-Out)
|
| 6 |
- 每个用户的最后一次评分 → test
|
| 7 |
-
- 每个用户的倒数第二次评分 → val
|
| 8 |
- 其余评分 → train
|
| 9 |
|
| 10 |
只保留评分 >= 3 次的用户 (有足够历史)
|
|
@@ -15,128 +15,71 @@ import numpy as np
|
|
| 15 |
from pathlib import Path
|
| 16 |
from tqdm import tqdm
|
| 17 |
import time
|
|
|
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
df.
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
group = group.sort_values('timestamp')
|
| 86 |
-
n = len(group)
|
| 87 |
-
|
| 88 |
-
# Last interaction → test
|
| 89 |
-
test_list.append(group.iloc[-1])
|
| 90 |
-
|
| 91 |
-
# Second-to-last → val
|
| 92 |
-
val_list.append(group.iloc[-2])
|
| 93 |
-
|
| 94 |
-
# The rest → train
|
| 95 |
-
train_list.extend(group.iloc[:-2].to_dict('records'))
|
| 96 |
-
|
| 97 |
-
# Convert to DataFrames
|
| 98 |
-
train_df = pd.DataFrame(train_list)
|
| 99 |
-
val_df = pd.DataFrame(val_list)
|
| 100 |
-
test_df = pd.DataFrame(test_list)
|
| 101 |
-
|
| 102 |
-
print(f'  Train set: {len(train_df):,} ({len(train_df)/len(df)*100:.1f}%)')
|
| 103 |
-
print(f'  Val set: {len(val_df):,} ({len(val_df)/len(df)*100:.1f}%)')
|
| 104 |
-
print(f'  Test set: {len(test_df):,} ({len(test_df)/len(df)*100:.1f}%)')
|
| 105 |
-
|
| 106 |
-
# ==================== 5. Save data ====================
|
| 107 |
-
print('\n[5/5] Saving data...')
|
| 108 |
-
|
| 109 |
-
train_df.to_csv(OUTPUT_DIR / 'train.csv', index=False)
|
| 110 |
-
val_df.to_csv(OUTPUT_DIR / 'val.csv', index=False)
|
| 111 |
-
test_df.to_csv(OUTPUT_DIR / 'test.csv', index=False)
|
| 112 |
-
|
| 113 |
-
# Save user list (for later evaluation)
|
| 114 |
-
active_users_df = pd.DataFrame({'user_id': active_users})
|
| 115 |
-
active_users_df.to_csv(OUTPUT_DIR / 'active_users.csv', index=False)
|
| 116 |
-
|
| 117 |
-
# Save statistics
|
| 118 |
-
stats = {
|
| 119 |
-
'total_records': len(df),
|
| 120 |
-
'train_records': len(train_df),
|
| 121 |
-
'val_records': len(val_df),
|
| 122 |
-
'test_records': len(test_df),
|
| 123 |
-
'active_users': len(active_users),
|
| 124 |
-
'books': df['isbn'].nunique(),
|
| 125 |
-
}
|
| 126 |
-
|
| 127 |
-
with open(OUTPUT_DIR / 'stats.txt', 'w') as f:
|
| 128 |
-
for k, v in stats.items():
|
| 129 |
-
f.write(f'{k}: {v:,}\n')
|
| 130 |
-
|
| 131 |
-
elapsed = time.time() - start_time
|
| 132 |
-
|
| 133 |
-
print('\n' + '='*60)
|
| 134 |
-
print('✅ Data split complete!')
|
| 135 |
-
print('='*60)
|
| 136 |
-
print(f'Output dir: {OUTPUT_DIR}')
|
| 137 |
-
print(f' - train.csv: {len(train_df):,} rows')
|
| 138 |
-
print(f' - val.csv: {len(val_df):,} rows')
|
| 139 |
-
print(f' - test.csv: {len(test_df):,} rows')
|
| 140 |
-
print(f' - active_users.csv: {len(active_users):,} users')
|
| 141 |
-
print(f'Elapsed time: {elapsed:.1f}s')
|
| 142 |
-
print('='*60)
|
|
|
|
| 4 |
|
| 5 |
Split strategy: temporal split (Leave-Last-Out)
|
| 6 |
- Each user's last rating → test
|
| 7 |
+
- Each user's second-to-last rating → val
|
| 8 |
- All remaining ratings → train
|
| 9 |
|
| 10 |
Only keep users with >= 3 ratings (enough history)
|
|
|
|
| 15 |
from pathlib import Path
|
| 16 |
from tqdm import tqdm
|
| 17 |
import time
|
| 18 |
+
import logging
|
| 19 |
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
DATA_PATH = Path("data/raw/Books_rating.csv")
|
| 23 |
+
OUTPUT_DIR = Path("data/rec")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def run(
|
| 27 |
+
data_path: Path = DATA_PATH,
|
| 28 |
+
output_dir: Path = OUTPUT_DIR,
|
| 29 |
+
) -> None:
|
| 30 |
+
"""Split train/val/test with Leave-Last-Out. Callable from Pipeline."""
|
| 31 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 32 |
+
start_time = time.time()
|
| 33 |
+
|
| 34 |
+
logger.info("Loading raw ratings...")
|
| 35 |
+
df = pd.read_csv(data_path, usecols=['Id', 'User_id', 'review/score', 'review/time', 'review/text'])
|
| 36 |
+
df.columns = ['isbn', 'user_id', 'rating', 'timestamp', 'review']
|
| 37 |
+
logger.info(f" Records: {len(df):,}, Users: {df['user_id'].nunique():,}, Items: {df['isbn'].nunique():,}")
|
| 38 |
+
|
| 39 |
+
logger.info("Cleaning data...")
|
| 40 |
+
df = df.drop_duplicates(subset=['user_id', 'isbn'], keep='last')
|
| 41 |
+
df = df.dropna(subset=['rating', 'timestamp'])
|
| 42 |
+
df = df[df['rating'] > 0]
|
| 43 |
+
|
| 44 |
+
logger.info("Filtering active users (>=3 interactions)...")
|
| 45 |
+
user_counts = df.groupby('user_id').size()
|
| 46 |
+
active_users = user_counts[user_counts >= 3].index
|
| 47 |
+
df = df[df['user_id'].isin(active_users)]
|
| 48 |
+
logger.info(f" Active users: {len(active_users):,}, Records: {len(df):,}")
|
| 49 |
+
|
| 50 |
+
logger.info("Splitting train/val/test (Leave-Last-Out)...")
|
| 51 |
+
df = df.sort_values(['user_id', 'timestamp'])
|
| 52 |
+
|
| 53 |
+
train_list = []
|
| 54 |
+
val_list = []
|
| 55 |
+
test_list = []
|
| 56 |
+
|
| 57 |
+
for user_id, group in tqdm(df.groupby('user_id'), desc=" Splitting"):
|
| 58 |
+
group = group.sort_values('timestamp')
|
| 59 |
+
test_list.append(group.iloc[-1])
|
| 60 |
+
val_list.append(group.iloc[-2])
|
| 61 |
+
train_list.extend(group.iloc[:-2].to_dict('records'))
|
| 62 |
+
|
| 63 |
+
train_df = pd.DataFrame(train_list)
|
| 64 |
+
val_df = pd.DataFrame(val_list)
|
| 65 |
+
test_df = pd.DataFrame(test_list)
|
| 66 |
+
|
| 67 |
+
logger.info(f" Train: {len(train_df):,}, Val: {len(val_df):,}, Test: {len(test_df):,}")
|
| 68 |
+
|
| 69 |
+
train_df.to_csv(output_dir / 'train.csv', index=False)
|
| 70 |
+
val_df.to_csv(output_dir / 'val.csv', index=False)
|
| 71 |
+
test_df.to_csv(output_dir / 'test.csv', index=False)
|
| 72 |
+
pd.DataFrame({'user_id': active_users}).to_csv(output_dir / 'active_users.csv', index=False)
|
| 73 |
+
|
| 74 |
+
with open(output_dir / 'stats.txt', 'w') as f:
|
| 75 |
+
for k, v in [('total_records', len(df)), ('train_records', len(train_df)),
|
| 76 |
+
('val_records', len(val_df)), ('test_records', len(test_df)),
|
| 77 |
+
('active_users', len(active_users)), ('books', df['isbn'].nunique())]:
|
| 78 |
+
f.write(f'{k}: {v:,}\n')
|
| 79 |
+
|
| 80 |
+
logger.info("Split complete in %.1fs", time.time() - start_time)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
if __name__ == "__main__":
|
| 84 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
| 85 |
+
run()
|
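As a quick illustration of the Leave-Last-Out rule applied inside the loop above (toy data, not from the repository): a user with four ratings at timestamps 1–4 contributes the first two to train, the third to val, and the last to test.

import pandas as pd

toy = pd.DataFrame({
    "user_id": ["u1"] * 4,
    "isbn": ["a", "b", "c", "d"],
    "rating": [5, 4, 3, 5],
    "timestamp": [1, 2, 3, 4],
})
group = toy.sort_values("timestamp")
test_row = group.iloc[-1]       # isbn "d"  -> test
val_row = group.iloc[-2]        # isbn "c"  -> val
train_rows = group.iloc[:-2]    # isbns "a", "b" -> train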
|
scripts/model/build_recall_models.py
CHANGED
|
@@ -1,76 +1,61 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
-
Build
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
| 7 |
|
| 8 |
Usage:
|
| 9 |
python scripts/model/build_recall_models.py
|
| 10 |
|
| 11 |
-
Input:
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
Output:
|
| 15 |
-
- data/model/recall/itemcf.pkl (~1.4 GB)
|
| 16 |
-
- data/model/recall/usercf.pkl (~70 MB)
|
| 17 |
-
- data/model/recall/swing.pkl
|
| 18 |
-
- data/model/recall/popularity.pkl
|
| 19 |
-
- data/model/recall/item2vec.pkl
|
| 20 |
-
|
| 21 |
-
Algorithms:
|
| 22 |
-
- ItemCF: Co-rating similarity with direction weight (forward=1.0, backward=0.7)
|
| 23 |
-
- UserCF: User similarity (Jaccard + activity penalty)
|
| 24 |
-
- Swing: User-pair overlap weighting for substitute relationships
|
| 25 |
-
- Popularity: Rating count with time decay
|
| 26 |
-
- Item2Vec: Word2Vec (Skip-gram) on user interaction sequences
|
| 27 |
"""
|
| 28 |
|
| 29 |
import sys
|
| 30 |
-
import
|
| 31 |
-
|
|
|
|
|
|
|
| 32 |
|
| 33 |
import pandas as pd
|
| 34 |
import logging
|
|
|
|
| 35 |
from src.recall.itemcf import ItemCF
|
| 36 |
from src.recall.usercf import UserCF
|
| 37 |
from src.recall.swing import Swing
|
| 38 |
from src.recall.popularity import PopularityRecall
|
| 39 |
from src.recall.item2vec import Item2Vec
|
| 40 |
|
| 41 |
-
logging.basicConfig(level=logging.INFO, format=
|
| 42 |
logger = logging.getLogger(__name__)
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
def main():
|
| 45 |
-
logger.info("Loading training data...")
|
| 46 |
-
df = pd.read_csv(
|
| 47 |
-
|
| 48 |
-
|
| 49 |
logger.info("--- Training ItemCF ---")
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
# 2. UserCF
|
| 54 |
logger.info("--- Training UserCF ---")
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
# 3. Swing
|
| 59 |
logger.info("--- Training Swing ---")
|
| 60 |
-
|
| 61 |
-
swing.fit(df)
|
| 62 |
|
| 63 |
-
# 4. Popularity
|
| 64 |
logger.info("--- Training Popularity ---")
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
# 5. Item2Vec
|
| 69 |
logger.info("--- Training Item2Vec ---")
|
| 70 |
-
|
| 71 |
-
|
|
|
|
| 72 |
|
| 73 |
-
logger.info("Recall models built and saved successfully!")
|
| 74 |
|
| 75 |
if __name__ == "__main__":
|
| 76 |
main()
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
Entry script: Build recall models (ItemCF, UserCF, Swing, Popularity, Item2Vec).
|
| 4 |
|
| 5 |
+
All training logic lives in src/recall/*.fit(). This script only loads data,
|
| 6 |
+
imports models, and calls fit().
|
| 7 |
|
| 8 |
Usage:
|
| 9 |
python scripts/model/build_recall_models.py
|
| 10 |
|
| 11 |
+
Input: data/rec/train.csv (columns: user_id, isbn, rating, timestamp)
|
| 12 |
+
Output: data/model/recall/*.pkl, data/recall_models.db (ItemCF)
|
|
|
| 13 |
"""
|
| 14 |
|
| 15 |
import sys
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
# Run from project root
|
| 19 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
|
| 20 |
|
| 21 |
import pandas as pd
|
| 22 |
import logging
|
| 23 |
+
|
| 24 |
from src.recall.itemcf import ItemCF
|
| 25 |
from src.recall.usercf import UserCF
|
| 26 |
from src.recall.swing import Swing
|
| 27 |
from src.recall.popularity import PopularityRecall
|
| 28 |
from src.recall.item2vec import Item2Vec
|
| 29 |
|
| 30 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
| 31 |
logger = logging.getLogger(__name__)
|
| 32 |
|
| 33 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
|
| 34 |
+
TRAIN_PATH = PROJECT_ROOT / "data" / "rec" / "train.csv"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
def main():
|
| 38 |
+
logger.info("Loading training data from %s...", TRAIN_PATH)
|
| 39 |
+
df = pd.read_csv(TRAIN_PATH)
|
| 40 |
+
logger.info("Loaded %d records.", len(df))
|
| 41 |
+
|
| 42 |
logger.info("--- Training ItemCF ---")
|
| 43 |
+
ItemCF().fit(df)
|
| 44 |
+
|
|
|
|
|
|
|
| 45 |
logger.info("--- Training UserCF ---")
|
| 46 |
+
UserCF().fit(df)
|
| 47 |
+
|
|
|
|
|
|
|
| 48 |
logger.info("--- Training Swing ---")
|
| 49 |
+
Swing().fit(df)
|
|
|
|
| 50 |
|
|
|
|
| 51 |
logger.info("--- Training Popularity ---")
|
| 52 |
+
PopularityRecall().fit(df)
|
| 53 |
+
|
|
|
|
|
|
|
| 54 |
logger.info("--- Training Item2Vec ---")
|
| 55 |
+
Item2Vec().fit(df)
|
| 56 |
+
|
| 57 |
+
logger.info("Recall models built and saved successfully.")
|
| 58 |
|
|
|
|
| 59 |
|
| 60 |
if __name__ == "__main__":
|
| 61 |
main()
|
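The removed docstring summarised ItemCF as "co-rating similarity with direction weight (forward=1.0, backward=0.7)". A hedged sketch of that idea, not the actual src/recall/itemcf.py implementation (which also handles persistence to recall_models.db):

from collections import defaultdict
import math

def itemcf_similarity(user_sequences, forward_w=1.0, backward_w=0.7):
    """user_sequences is assumed to map user_id -> list of ISBNs ordered by timestamp."""
    sim = defaultdict(lambda: defaultdict(float))
    item_count = defaultdict(int)
    for items in user_sequences.values():
        for idx, i in enumerate(items):
            item_count[i] += 1
            for j in items[idx + 1:]:
                sim[i][j] += forward_w    # i was rated before j
                sim[j][i] += backward_w   # reverse direction, weaker
    # cosine-style normalization by item popularity
    for i in sim:
        for j in sim[i]:
            sim[i][j] /= math.sqrt(item_count[i] * item_count[j])
    return sim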
scripts/model/train_sasrec.py
CHANGED
|
@@ -1,204 +1,44 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
-
Train SASRec
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
| 7 |
|
| 8 |
Usage:
|
| 9 |
python scripts/model/train_sasrec.py
|
| 10 |
|
| 11 |
-
Input:
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
Output:
|
| 16 |
-
- data/model/rec/sasrec_model.pth (model weights)
|
| 17 |
-
- data/rec/user_seq_emb.pkl (user sequence embeddings)
|
| 18 |
-
|
| 19 |
-
Architecture:
|
| 20 |
-
- Self-Attention layers (Transformer encoder)
|
| 21 |
-
- Positional embeddings
|
| 22 |
-
- BCE loss with negative sampling
|
| 23 |
-
|
| 24 |
-
Recommended:
|
| 25 |
-
- GPU: 30 epochs, ~20 minutes
|
| 26 |
-
- The user embeddings are used as features in LGBMRanker and as an independent recall channel
|
| 27 |
"""
|
| 28 |
|
| 29 |
import sys
|
| 30 |
-
import
|
| 31 |
-
|
|
|
|
| 32 |
|
| 33 |
-
import
|
| 34 |
-
import torch.nn as nn
|
| 35 |
-
import torch.optim as optim
|
| 36 |
-
from torch.utils.data import Dataset, DataLoader
|
| 37 |
-
import pickle
|
| 38 |
-
import numpy as np
|
| 39 |
import logging
|
| 40 |
-
from tqdm import tqdm
|
| 41 |
-
from pathlib import Path
|
| 42 |
-
from src.model.sasrec import SASRec
|
| 43 |
|
| 44 |
-
|
|
|
|
|
|
|
| 45 |
logger = logging.getLogger(__name__)
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
seq_len = min(len(s), max_len)
|
| 60 |
-
seq_processed[-seq_len:] = s[-seq_len:]
|
| 61 |
-
|
| 62 |
-
self.seqs.append(seq_processed)
|
| 63 |
-
|
| 64 |
-
self.seqs = torch.LongTensor(self.seqs)
|
| 65 |
-
|
| 66 |
-
def __len__(self):
|
| 67 |
-
return len(self.seqs)
|
| 68 |
-
|
| 69 |
-
def __getitem__(self, idx):
|
| 70 |
-
seq = self.seqs[idx]
|
| 71 |
-
|
| 72 |
-
# Determine pos/neg for training
|
| 73 |
-
# Target: seq shifted right (positives)
|
| 74 |
-
pos = np.zeros_like(seq)
|
| 75 |
-
pos[:-1] = seq[1:]
|
| 76 |
-
|
| 77 |
-
# Negatives: random sample
|
| 78 |
-
neg = np.random.randint(1, self.num_items + 1, size=len(seq))
|
| 79 |
-
|
| 80 |
-
return seq, torch.LongTensor(pos), torch.LongTensor(neg)
|
| 81 |
|
| 82 |
-
def train_sasrec():
|
| 83 |
-
max_len = 50
|
| 84 |
-
hidden_dim = 64
|
| 85 |
-
batch_size = 128
|
| 86 |
-
epochs = 30 # Increased from 3
|
| 87 |
-
lr = 1e-4 # Aligned with optimizer
|
| 88 |
-
|
| 89 |
-
data_dir = Path('data/rec')
|
| 90 |
-
logger.info("Loading sequences...")
|
| 91 |
-
with open(data_dir / 'user_sequences.pkl', 'rb') as f:
|
| 92 |
-
seqs_dict = pickle.load(f)
|
| 93 |
-
|
| 94 |
-
with open(data_dir / 'item_map.pkl', 'rb') as f:
|
| 95 |
-
item_map = pickle.load(f)
|
| 96 |
-
num_items = len(item_map)
|
| 97 |
-
|
| 98 |
-
dataset = SeqDataset(seqs_dict, num_items, max_len)
|
| 99 |
-
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
| 100 |
-
|
| 101 |
-
# Check for MPS (Mac GPU) or CUDA
|
| 102 |
-
if torch.cuda.is_available():
|
| 103 |
-
device = torch.device('cuda')
|
| 104 |
-
elif torch.backends.mps.is_available():
|
| 105 |
-
device = torch.device('mps')
|
| 106 |
-
else:
|
| 107 |
-
device = torch.device('cpu')
|
| 108 |
-
|
| 109 |
-
logger.info(f"Training on {device}")
|
| 110 |
-
|
| 111 |
-
model = SASRec(num_items, max_len, hidden_dim).to(device)
|
| 112 |
-
optimizer = optim.Adam(model.parameters(), lr=1e-4) # Reduced LR
|
| 113 |
-
|
| 114 |
-
# BCE Loss for Pos/Neg
|
| 115 |
-
criterion = nn.BCEWithLogitsLoss()
|
| 116 |
-
|
| 117 |
-
model.train()
|
| 118 |
-
for epoch in range(epochs):
|
| 119 |
-
total_loss = 0
|
| 120 |
-
pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
|
| 121 |
-
for seq, pos, neg in pbar:
|
| 122 |
-
seq = seq.to(device)
|
| 123 |
-
pos = pos.to(device)
|
| 124 |
-
neg = neg.to(device)
|
| 125 |
-
|
| 126 |
-
# Forward pass to get seq embeddings: [B, L, H]
|
| 127 |
-
seq_emb = model(seq) # [B, L, H]
|
| 128 |
-
|
| 129 |
-
# Mask padding (0) in targets
|
| 130 |
-
mask = (pos != 0)
|
| 131 |
-
|
| 132 |
-
# Get Item Embeddings for Pos and Neg
|
| 133 |
-
pos_emb = model.item_emb(pos) # [B, L, H]
|
| 134 |
-
neg_emb = model.item_emb(neg) # [B, L, H]
|
| 135 |
-
|
| 136 |
-
# Calculate logits
|
| 137 |
-
pos_logits = (seq_emb * pos_emb).sum(dim=-1)
|
| 138 |
-
neg_logits = (seq_emb * neg_emb).sum(dim=-1)
|
| 139 |
-
|
| 140 |
-
pos_logits = pos_logits[mask]
|
| 141 |
-
neg_logits = neg_logits[mask]
|
| 142 |
-
|
| 143 |
-
pos_labels = torch.ones_like(pos_logits)
|
| 144 |
-
neg_labels = torch.zeros_like(neg_logits)
|
| 145 |
-
|
| 146 |
-
loss = criterion(pos_logits, pos_labels) + criterion(neg_logits, neg_labels)
|
| 147 |
-
|
| 148 |
-
optimizer.zero_grad()
|
| 149 |
-
loss.backward()
|
| 150 |
-
|
| 151 |
-
# Clip Gradient
|
| 152 |
-
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 153 |
-
|
| 154 |
-
optimizer.step()
|
| 155 |
-
|
| 156 |
-
total_loss += loss.item()
|
| 157 |
-
pbar.set_postfix({'loss': total_loss / (pbar.n + 1)})
|
| 158 |
-
|
| 159 |
-
# Save Model
|
| 160 |
-
torch.save(model.state_dict(), data_dir / '../model/rec/sasrec_model.pth')
|
| 161 |
-
|
| 162 |
-
# Save User Embeddings (Last hidden state)
|
| 163 |
-
logger.info("Extracting User Sequence Embeddings...")
|
| 164 |
-
model.eval()
|
| 165 |
-
user_emb_dict = {}
|
| 166 |
-
|
| 167 |
-
# Create evaluation loader (no shuffle, keep user order)
|
| 168 |
-
# We need to map back to user_ids, so iterate dict directly
|
| 169 |
-
|
| 170 |
-
# Batch processing for inference
|
| 171 |
-
all_users = list(seqs_dict.keys())
|
| 172 |
-
|
| 173 |
-
with torch.no_grad():
|
| 174 |
-
for i in tqdm(range(0, len(all_users), batch_size)):
|
| 175 |
-
batch_users = all_users[i : i+batch_size]
|
| 176 |
-
batch_seqs = []
|
| 177 |
-
for u in batch_users:
|
| 178 |
-
s = seqs_dict[u]
|
| 179 |
-
# Same padding logic
|
| 180 |
-
seq_processed = [0] * max_len
|
| 181 |
-
seq_len = min(len(s), max_len)
|
| 182 |
-
if seq_len > 0:
|
| 183 |
-
seq_processed[-seq_len:] = s[-seq_len:]
|
| 184 |
-
batch_seqs.append(seq_processed)
|
| 185 |
-
|
| 186 |
-
input_tensor = torch.LongTensor(batch_seqs).to(device)
|
| 187 |
-
|
| 188 |
-
# Initial forward
|
| 189 |
-
# Note: During inference, we use the FULL sequence to predict the FUTURE (Test Item)
|
| 190 |
-
# So input is the full available history
|
| 191 |
-
|
| 192 |
-
output = model(input_tensor) # [B, L, H]
|
| 193 |
-
last_state = output[:, -1, :].cpu().numpy() # [B, H]
|
| 194 |
-
|
| 195 |
-
for j, u in enumerate(batch_users):
|
| 196 |
-
user_emb_dict[u] = last_state[j]
|
| 197 |
-
|
| 198 |
-
with open(data_dir / 'user_seq_emb.pkl', 'wb') as f:
|
| 199 |
-
pickle.dump(user_emb_dict, f)
|
| 200 |
-
|
| 201 |
-
logger.info("User Seq Embeddings saved.")
|
| 202 |
|
| 203 |
if __name__ == "__main__":
|
| 204 |
-
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
Entry script: Train SASRec sequential recommendation model.
|
| 4 |
|
| 5 |
+
All training logic lives in SASRecRecall.fit(). This script loads data
|
| 6 |
+
and calls fit().
|
| 7 |
|
| 8 |
Usage:
|
| 9 |
python scripts/model/train_sasrec.py
|
| 10 |
|
| 11 |
+
Input: data/rec/train.csv
|
| 12 |
+
Output: data/model/rec/sasrec_model.pth
|
| 13 |
+
data/rec/user_seq_emb.pkl, item_map.pkl, user_sequences.pkl
|
|
|
|
|
|
| 14 |
"""
|
| 15 |
|
| 16 |
import sys
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
|
| 20 |
|
| 21 |
+
import pandas as pd
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
import logging
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
+
from src.recall.sasrec_recall import SASRecRecall
|
| 25 |
+
|
| 26 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
| 27 |
logger = logging.getLogger(__name__)
|
| 28 |
|
| 29 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
|
| 30 |
+
TRAIN_PATH = PROJECT_ROOT / "data" / "rec" / "train.csv"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def main():
|
| 34 |
+
logger.info("Loading training data from %s...", TRAIN_PATH)
|
| 35 |
+
df = pd.read_csv(TRAIN_PATH)
|
| 36 |
+
logger.info("Loaded %d records.", len(df))
|
| 37 |
+
|
| 38 |
+
model = SASRecRecall()
|
| 39 |
+
model.fit(df)
|
| 40 |
+
logger.info("SASRec training complete.")
|
|
|
|
|
|
|
|
| 41 |
|
|
|
|
|
|
| 42 |
|
| 43 |
if __name__ == "__main__":
|
| 44 |
+
main()
|
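The training loop deleted above documents the objective that SASRecRecall.fit() is now expected to encapsulate: BCE over dot-product logits of positive (next-item) and randomly sampled negative embeddings, with padded positions masked out. A condensed sketch of that step, with model and tensor shapes as in the removed code:

import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()

def sasrec_training_step(model, seq, pos, neg):
    """One BCE step with negative sampling; padding id 0 is masked."""
    seq_emb = model(seq)                                   # [B, L, H]
    mask = pos != 0
    pos_logits = (seq_emb * model.item_emb(pos)).sum(-1)[mask]
    neg_logits = (seq_emb * model.item_emb(neg)).sum(-1)[mask]
    loss = bce(pos_logits, torch.ones_like(pos_logits)) + \
           bce(neg_logits, torch.zeros_like(neg_logits))
    return loss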
scripts/model/train_youtube_dnn.py
CHANGED
|
@@ -1,239 +1,45 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
-
Train YoutubeDNN Two-Tower
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
| 7 |
|
| 8 |
Usage:
|
| 9 |
python scripts/model/train_youtube_dnn.py
|
| 10 |
|
| 11 |
-
Input:
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
Output:
|
| 17 |
-
- data/model/recall/youtube_dnn.pt (model weights)
|
| 18 |
-
- data/model/recall/youtube_dnn_meta.pkl (config + mappings)
|
| 19 |
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
- Item Tower: Embedding(item) + Embedding(category) -> MLP
|
| 23 |
-
- Training: Contrastive loss with in-batch negatives
|
| 24 |
|
| 25 |
-
|
| 26 |
-
- GPU: 10-50 epochs, ~30 minutes
|
| 27 |
-
- CPU: 3-5 epochs for testing only
|
| 28 |
-
"""
|
| 29 |
|
| 30 |
-
import numpy as np
|
| 31 |
import pandas as pd
|
| 32 |
-
import
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
from torch.utils.data import Dataset, DataLoader
|
| 36 |
-
import pickle
|
| 37 |
-
from pathlib import Path
|
| 38 |
-
from tqdm import tqdm
|
| 39 |
-
import sys
|
| 40 |
-
import os
|
| 41 |
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
from src.recall.youtube_dnn import YoutubeDNN
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
LR = 0.001
|
| 50 |
-
EMBED_DIM = 64
|
| 51 |
-
MAX_HISTORY = 20
|
| 52 |
-
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 53 |
-
if torch.backends.mps.is_available():
|
| 54 |
-
DEVICE = torch.device('mps')
|
| 55 |
|
| 56 |
-
print(f"Using device: {DEVICE}")
|
| 57 |
|
| 58 |
-
def
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
item_map = pickle.load(f)
|
| 63 |
-
isbn_to_id = item_map
|
| 64 |
-
id_to_isbn = {v: k for k, v in item_map.items()}
|
| 65 |
-
|
| 66 |
-
# Load sequences
|
| 67 |
-
with open('data/rec/user_sequences.pkl', 'rb') as f:
|
| 68 |
-
user_seqs = pickle.load(f)
|
| 69 |
-
|
| 70 |
-
# Load book features for category mapping
|
| 71 |
-
books_df = pd.read_csv('data/books_processed.csv', usecols=['isbn13', 'simple_categories'])
|
| 72 |
-
books_df['isbn'] = books_df['isbn13'].astype(str)
|
| 73 |
-
|
| 74 |
-
# Create category map
|
| 75 |
-
# Categories are often strings like 'Fiction', 'Juvenile Fiction'. We take the first one.
|
| 76 |
-
cate_map = {'<PAD>': 0, '<UNK>': 1}
|
| 77 |
-
item_to_cate = {}
|
| 78 |
-
|
| 79 |
-
print("Building category map...")
|
| 80 |
-
for _, row in books_df.iterrows():
|
| 81 |
-
isbn = row['isbn']
|
| 82 |
-
if isbn in isbn_to_id:
|
| 83 |
-
iid = isbn_to_id[isbn]
|
| 84 |
-
cates = str(row['simple_categories']).split(';')
|
| 85 |
-
main_cate = cates[0].strip() if cates else 'Unknown'
|
| 86 |
-
|
| 87 |
-
if main_cate not in cate_map:
|
| 88 |
-
cate_map[main_cate] = len(cate_map)
|
| 89 |
-
|
| 90 |
-
item_to_cate[iid] = cate_map[main_cate]
|
| 91 |
-
|
| 92 |
-
# Default category for unknown items
|
| 93 |
-
default_cate = cate_map.get('Unknown', 1)
|
| 94 |
-
|
| 95 |
-
return user_seqs, item_to_cate, len(item_map)+1, len(cate_map), default_cate
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
self.item_to_cate = item_to_cate
|
| 101 |
-
self.default_cate = default_cate
|
| 102 |
-
self.max_history = max_history
|
| 103 |
-
|
| 104 |
-
print("Generating training samples...")
|
| 105 |
-
# Leave-Last-Out:
|
| 106 |
-
# Last item -> Test
|
| 107 |
-
# 2nd Last -> Val
|
| 108 |
-
# Rest -> Train
|
| 109 |
-
# So we use items 0 to N-3 for training history generation
|
| 110 |
-
|
| 111 |
-
for user, seq in tqdm(user_seqs.items()):
|
| 112 |
-
if len(seq) < 3:
|
| 113 |
-
continue
|
| 114 |
-
|
| 115 |
-
# Use data up to the split point for training
|
| 116 |
-
# Valid Train Set: seq[:-2]
|
| 117 |
-
train_seq = seq[:-2]
|
| 118 |
-
|
| 119 |
-
# Generate sliding window samples
|
| 120 |
-
# minimum history length = 1
|
| 121 |
-
for i in range(1, len(train_seq)):
|
| 122 |
-
target = train_seq[i]
|
| 123 |
-
history = train_seq[:i]
|
| 124 |
-
|
| 125 |
-
# Truncate history
|
| 126 |
-
if len(history) > max_history:
|
| 127 |
-
history = history[-max_history:]
|
| 128 |
-
|
| 129 |
-
self.samples.append((history, target))
|
| 130 |
-
|
| 131 |
-
print(f"Total training samples: {len(self.samples)}")
|
| 132 |
-
|
| 133 |
-
def __len__(self):
|
| 134 |
-
return len(self.samples)
|
| 135 |
-
|
| 136 |
-
def __getitem__(self, idx):
|
| 137 |
-
history, target = self.samples[idx]
|
| 138 |
-
|
| 139 |
-
# Padding history
|
| 140 |
-
padded_hist = np.zeros(self.max_history, dtype=np.int64)
|
| 141 |
-
length = min(len(history), self.max_history)
|
| 142 |
-
if length > 0:
|
| 143 |
-
padded_hist[:length] = history[-length:]
|
| 144 |
-
|
| 145 |
-
target_cate = self.item_to_cate.get(target, self.default_cate)
|
| 146 |
-
|
| 147 |
-
return torch.LongTensor(padded_hist), torch.tensor(target, dtype=torch.long), torch.tensor(target_cate, dtype=torch.long)
|
| 148 |
|
| 149 |
-
def train():
|
| 150 |
-
user_seqs, item_to_cate, vocab_size, cate_vocab_size, default_cate = load_data()
|
| 151 |
-
|
| 152 |
-
dataset = RetrievalDataset(user_seqs, item_to_cate, default_cate, MAX_HISTORY)
|
| 153 |
-
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0) # mp issue on mac sometimes
|
| 154 |
-
|
| 155 |
-
# Model Setup
|
| 156 |
-
user_config = {
|
| 157 |
-
'vocab_size': vocab_size,
|
| 158 |
-
'embed_dim': EMBED_DIM,
|
| 159 |
-
'history_len': MAX_HISTORY
|
| 160 |
-
}
|
| 161 |
-
item_config = {
|
| 162 |
-
'vocab_size': vocab_size,
|
| 163 |
-
'embed_dim': EMBED_DIM,
|
| 164 |
-
'cate_vocab_size': cate_vocab_size,
|
| 165 |
-
'cate_embed_dim': 32
|
| 166 |
-
}
|
| 167 |
-
model_config = {
|
| 168 |
-
'hidden_dims': [128, 64],
|
| 169 |
-
'dropout': 0.1
|
| 170 |
-
}
|
| 171 |
-
|
| 172 |
-
model = YoutubeDNN(user_config, item_config, model_config).to(DEVICE)
|
| 173 |
-
optimizer = optim.Adam(model.parameters(), lr=LR)
|
| 174 |
-
criterion = nn.CrossEntropyLoss() # For In-Batch Negatives
|
| 175 |
-
|
| 176 |
-
print("Start Training...")
|
| 177 |
-
model.train()
|
| 178 |
-
|
| 179 |
-
for epoch in range(EPOCHS):
|
| 180 |
-
total_loss = 0
|
| 181 |
-
steps = 0
|
| 182 |
-
|
| 183 |
-
pbar = tqdm(dataloader)
|
| 184 |
-
for history, target_item, target_cate in pbar:
|
| 185 |
-
history = history.to(DEVICE)
|
| 186 |
-
target_item = target_item.to(DEVICE)
|
| 187 |
-
target_cate = target_cate.to(DEVICE)
|
| 188 |
-
|
| 189 |
-
optimizer.zero_grad()
|
| 190 |
-
|
| 191 |
-
# Get Vectors
|
| 192 |
-
user_vec = model.user_tower(history) # (B, D)
|
| 193 |
-
item_vec = model.item_tower(target_item, target_cate) # (B, D)
|
| 194 |
-
|
| 195 |
-
# Normalize
|
| 196 |
-
user_vec = nn.functional.normalize(user_vec, p=2, dim=1)
|
| 197 |
-
item_vec = nn.functional.normalize(item_vec, p=2, dim=1)
|
| 198 |
-
|
| 199 |
-
# In-Batch Negatives
|
| 200 |
-
# logits[i][j] = user_i dot item_j
|
| 201 |
-
# We want logits[i][i] to be high
|
| 202 |
-
logits = torch.matmul(user_vec, item_vec.t()) # (B, B)
|
| 203 |
-
|
| 204 |
-
# Temperature scaling (optional, helps convergence)
|
| 205 |
-
logits = logits / 0.1
|
| 206 |
-
|
| 207 |
-
labels = torch.arange(len(user_vec)).to(DEVICE)
|
| 208 |
-
|
| 209 |
-
loss = criterion(logits, labels)
|
| 210 |
-
loss.backward()
|
| 211 |
-
optimizer.step()
|
| 212 |
-
|
| 213 |
-
total_loss += loss.item()
|
| 214 |
-
steps += 1
|
| 215 |
-
|
| 216 |
-
pbar.set_description(f"Epoch {epoch+1} Loss: {total_loss/steps:.4f}")
|
| 217 |
-
|
| 218 |
-
print(f"Epoch {epoch+1} finished. Avg Loss: {total_loss/steps:.4f}")
|
| 219 |
-
|
| 220 |
-
# Save Model
|
| 221 |
-
save_path = Path('data/model/recall')
|
| 222 |
-
save_path.mkdir(parents=True, exist_ok=True)
|
| 223 |
-
|
| 224 |
-
torch.save(model.state_dict(), save_path / 'youtube_dnn.pt')
|
| 225 |
-
|
| 226 |
-
# Save metadata
|
| 227 |
-
meta = {
|
| 228 |
-
'user_config': user_config,
|
| 229 |
-
'item_config': item_config,
|
| 230 |
-
'model_config': model_config,
|
| 231 |
-
'item_to_cate': item_to_cate
|
| 232 |
-
}
|
| 233 |
-
with open(save_path / 'youtube_dnn_meta.pkl', 'wb') as f:
|
| 234 |
-
pickle.dump(meta, f)
|
| 235 |
-
|
| 236 |
-
print(f"Model saved to {save_path}")
|
| 237 |
|
| 238 |
-
if __name__ ==
|
| 239 |
-
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
Entry script: Train YoutubeDNN Two-Tower recall model.
|
| 4 |
|
| 5 |
+
All training logic lives in YoutubeDNNRecall.fit(). This script loads data
|
| 6 |
+
and calls fit().
|
| 7 |
|
| 8 |
Usage:
|
| 9 |
python scripts/model/train_youtube_dnn.py
|
| 10 |
|
| 11 |
+
Input: data/rec/train.csv
|
| 12 |
+
Output: data/model/recall/youtube_dnn.pt, youtube_dnn_meta.pkl
|
| 13 |
+
data/rec/item_map.pkl, user_sequences.pkl
|
| 14 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
+
import sys
|
| 17 |
+
from pathlib import Path
|
|
|
|
|
|
|
| 18 |
|
| 19 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
|
|
|
|
|
|
|
|
|
|
| 20 |
|
|
|
|
| 21 |
import pandas as pd
|
| 22 |
+
import logging
|
| 23 |
+
|
| 24 |
+
from src.recall.embedding import YoutubeDNNRecall
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
| 27 |
+
logger = logging.getLogger(__name__)
|
|
|
|
| 28 |
|
| 29 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
|
| 30 |
+
TRAIN_PATH = PROJECT_ROOT / "data" / "rec" / "train.csv"
|
| 31 |
+
BOOKS_PATH = PROJECT_ROOT / "data" / "books_processed.csv"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
|
|
|
| 33 |
|
| 34 |
+
def main():
|
| 35 |
+
logger.info("Loading training data from %s...", TRAIN_PATH)
|
| 36 |
+
df = pd.read_csv(TRAIN_PATH)
|
| 37 |
+
logger.info("Loaded %d records.", len(df))
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
+
model = YoutubeDNNRecall()
|
| 40 |
+
model.fit(df, books_path=BOOKS_PATH)
|
| 41 |
+
logger.info("YoutubeDNN training complete.")
|
|
|
|
|
|
|
| 42 |
|
|
|
|
|
|
|
| 43 |
|
| 44 |
+
if __name__ == "__main__":
|
| 45 |
+
main()
|
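Likewise, the deleted loop above shows the in-batch-negatives objective that YoutubeDNNRecall.fit() is assumed to reproduce: L2-normalised user and item vectors, a temperature-scaled similarity matrix, and cross-entropy against the diagonal. A self-contained sketch:

import torch
import torch.nn.functional as F

def in_batch_contrastive_loss(user_vec: torch.Tensor, item_vec: torch.Tensor,
                              temperature: float = 0.1) -> torch.Tensor:
    """user_vec, item_vec: [B, D]; row i of each side forms the positive pair."""
    user_vec = F.normalize(user_vec, p=2, dim=1)
    item_vec = F.normalize(item_vec, p=2, dim=1)
    logits = user_vec @ item_vec.t() / temperature         # [B, B]
    labels = torch.arange(user_vec.size(0), device=user_vec.device)
    return F.cross_entropy(logits, labels)                 # positives on the diagonal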
scripts/run_pipeline.py
CHANGED
|
@@ -2,8 +2,8 @@
|
|
| 2 |
"""
|
| 3 |
Unified Data Pipeline Runner
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
| 7 |
|
| 8 |
Usage:
|
| 9 |
python scripts/run_pipeline.py # Full pipeline
|
|
@@ -13,156 +13,208 @@ Usage:
|
|
| 13 |
"""
|
| 14 |
|
| 15 |
import argparse
|
| 16 |
-
import
|
| 17 |
import sys
|
| 18 |
import time
|
| 19 |
from pathlib import Path
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
|
| 47 |
def main():
|
| 48 |
-
parser = argparse.ArgumentParser(description="Run data pipeline")
|
| 49 |
-
parser.add_argument(
|
| 50 |
-
"
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
| 52 |
parser.add_argument("--skip-models", action="store_true", help="Skip model training")
|
| 53 |
parser.add_argument("--skip-index", action="store_true", help="Skip index building")
|
| 54 |
parser.add_argument("--validate-only", action="store_true", help="Only run validation")
|
| 55 |
-
parser.add_argument("--device", default=None, help="Device for ML
|
| 56 |
-
parser.add_argument("--stacking", action="store_true", help="Enable
|
| 57 |
args = parser.parse_args()
|
| 58 |
-
|
| 59 |
-
print("=" * 60)
|
| 60 |
-
print("🚀 DATA PIPELINE RUNNER")
|
| 61 |
-
print("=" * 60)
|
| 62 |
-
|
| 63 |
if args.validate_only:
|
| 64 |
-
|
|
|
|
| 65 |
return
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
if args.stage in ["all", "books"]:
|
| 73 |
-
run_script(
|
| 74 |
-
"scripts/data/clean_data.py",
|
| 75 |
-
"Cleaning text data (HTML, encoding, whitespace)",
|
| 76 |
-
args=["--backup"]
|
| 77 |
-
)
|
| 78 |
-
|
| 79 |
-
run_script(
|
| 80 |
-
"scripts/data/build_books_basic_info.py",
|
| 81 |
-
"Building books basic info"
|
| 82 |
-
)
|
| 83 |
-
|
| 84 |
-
device_args = ["--device", args.device] if args.device else []
|
| 85 |
-
run_script(
|
| 86 |
-
"scripts/data/generate_emotions.py",
|
| 87 |
-
"Generating emotion scores",
|
| 88 |
-
args=device_args
|
| 89 |
-
)
|
| 90 |
-
|
| 91 |
-
run_script(
|
| 92 |
-
"scripts/data/generate_tags.py",
|
| 93 |
-
"Generating tags"
|
| 94 |
-
)
|
| 95 |
-
|
| 96 |
-
run_script(
|
| 97 |
-
"scripts/data/chunk_reviews.py",
|
| 98 |
-
"Chunking reviews for Small-to-Big"
|
| 99 |
-
)
|
| 100 |
-
|
| 101 |
-
# ==========================================================================
|
| 102 |
-
# Stage 2: RecSys Data Preparation
|
| 103 |
-
# ==========================================================================
|
| 104 |
-
if args.stage in ["all", "rec"]:
|
| 105 |
-
run_script(
|
| 106 |
-
"scripts/data/split_rec_data.py",
|
| 107 |
-
"Splitting train/val/test data"
|
| 108 |
-
)
|
| 109 |
-
|
| 110 |
-
run_script(
|
| 111 |
-
"scripts/data/build_sequences.py",
|
| 112 |
-
"Building user sequences"
|
| 113 |
-
)
|
| 114 |
-
|
| 115 |
-
# ==========================================================================
|
| 116 |
-
# Stage 3: Index Building
|
| 117 |
-
# ==========================================================================
|
| 118 |
-
if args.stage in ["all", "index"] and not args.skip_index:
|
| 119 |
-
run_script(
|
| 120 |
-
"scripts/init_sqlite_db.py",
|
| 121 |
-
"Building SQLite metadata (books.db)"
|
| 122 |
-
)
|
| 123 |
-
run_script(
|
| 124 |
-
"scripts/data/init_dual_index.py",
|
| 125 |
-
"Building chunk vector index"
|
| 126 |
-
)
|
| 127 |
-
|
| 128 |
-
# ==========================================================================
|
| 129 |
-
# Stage 4: Model Training
|
| 130 |
-
# ==========================================================================
|
| 131 |
-
if args.stage in ["all", "models"] and not args.skip_models:
|
| 132 |
-
run_script(
|
| 133 |
-
"scripts/model/build_recall_models.py",
|
| 134 |
-
"Building ItemCF/UserCF models"
|
| 135 |
-
)
|
| 136 |
-
|
| 137 |
-
run_script(
|
| 138 |
-
"scripts/model/train_youtube_dnn.py",
|
| 139 |
-
"Training YoutubeDNN (requires GPU)"
|
| 140 |
-
)
|
| 141 |
-
|
| 142 |
-
run_script(
|
| 143 |
-
"scripts/model/train_sasrec.py",
|
| 144 |
-
"Training SASRec (requires GPU)"
|
| 145 |
-
)
|
| 146 |
-
|
| 147 |
-
ranker_args = ["--stacking"] if args.stacking else []
|
| 148 |
-
run_script(
|
| 149 |
-
"scripts/model/train_ranker.py",
|
| 150 |
-
"Training LGBMRanker (Stacking: {})".format("ON" if args.stacking else "OFF"),
|
| 151 |
-
args=ranker_args
|
| 152 |
-
)
|
| 153 |
-
|
| 154 |
-
# ==========================================================================
|
| 155 |
-
# Final Validation
|
| 156 |
-
# ==========================================================================
|
| 157 |
-
run_script(
|
| 158 |
-
"scripts/data/validate_data.py",
|
| 159 |
-
"Final validation"
|
| 160 |
)
|
| 161 |
-
|
| 162 |
-
elapsed_total = time.time() - start_total
|
| 163 |
-
print("\n" + "=" * 60)
|
| 164 |
-
print(f"🎉 PIPELINE COMPLETED in {elapsed_total/60:.1f} minutes")
|
| 165 |
-
print("=" * 60)
|
| 166 |
|
| 167 |
|
| 168 |
if __name__ == "__main__":
|
|
|
|
| 2 |
"""
|
| 3 |
Unified Data Pipeline Runner
|
| 4 |
|
| 5 |
+
Orchestrates Data Cleaning -> Training -> Evaluation using direct Python imports.
|
| 6 |
+
No subprocess calls. All logic invoked via Module.run() or src classes.
|
| 7 |
|
| 8 |
Usage:
|
| 9 |
python scripts/run_pipeline.py # Full pipeline
|
|
|
|
| 13 |
"""
|
| 14 |
|
| 15 |
import argparse
|
| 16 |
+
import logging
|
| 17 |
import sys
|
| 18 |
import time
|
| 19 |
from pathlib import Path
|
| 20 |
|
| 21 |
+
# Ensure project root is on path
|
| 22 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 23 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 24 |
+
|
| 25 |
+
logging.basicConfig(
|
| 26 |
+
level=logging.INFO,
|
| 27 |
+
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
| 28 |
+
datefmt="%H:%M:%S",
|
| 29 |
+
)
|
| 30 |
+
logger = logging.getLogger("pipeline")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class Pipeline:
|
| 34 |
+
"""
|
| 35 |
+
Manages the full data pipeline: Data Cleaning -> Training -> Evaluation.
|
| 36 |
+
|
| 37 |
+
All stages use direct Python imports; no subprocess.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
project_root: Path = PROJECT_ROOT,
|
| 43 |
+
device: str | None = None,
|
| 44 |
+
skip_models: bool = False,
|
| 45 |
+
skip_index: bool = False,
|
| 46 |
+
stacking: bool = False,
|
| 47 |
+
):
|
| 48 |
+
self.project_root = Path(project_root)
|
| 49 |
+
self.data_dir = self.project_root / "data"
|
| 50 |
+
self.rec_dir = self.data_dir / "rec"
|
| 51 |
+
self.model_dir = self.data_dir / "model"
|
| 52 |
+
self.device = device
|
| 53 |
+
self.skip_models = skip_models
|
| 54 |
+
self.skip_index = skip_index
|
| 55 |
+
self.stacking = stacking
|
| 56 |
+
|
| 57 |
+
def _run_step(self, name: str, fn, *args, **kwargs):
|
| 58 |
+
"""Run a step with timing log."""
|
| 59 |
+
logger.info("▶ %s", name)
|
| 60 |
+
start = time.time()
|
| 61 |
+
fn(*args, **kwargs)
|
| 62 |
+
logger.info(" ✓ Done in %.1fs", time.time() - start)
|
| 63 |
+
|
| 64 |
+
def run_data_cleaning(self, stage: str = "all") -> None:
|
| 65 |
+
"""Stage 1: Book data processing."""
|
| 66 |
+
if stage not in ("all", "books"):
|
| 67 |
+
return
|
| 68 |
+
|
| 69 |
+
from scripts.data.clean_data import run as clean_run
|
| 70 |
+
self._run_step("Clean text data", clean_run, backup=True)
|
| 71 |
+
|
| 72 |
+
from scripts.data.build_books_basic_info import run as build_run
|
| 73 |
+
raw_dir = self.data_dir / "raw"
|
| 74 |
+
self._run_step("Build books basic info", build_run,
|
| 75 |
+
books_path=raw_dir / "books_data.csv",
|
| 76 |
+
ratings_path=raw_dir / "Books_rating.csv",
|
| 77 |
+
output_path=self.data_dir / "books_basic_info.csv",
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
from scripts.data.generate_emotions import run as emotions_run
|
| 81 |
+
self._run_step("Generate emotion scores", emotions_run, device=self.device)
|
| 82 |
+
|
| 83 |
+
from scripts.data.generate_tags import run as tags_run
|
| 84 |
+
self._run_step("Generate tags", tags_run)
|
| 85 |
+
|
| 86 |
+
from scripts.data.chunk_reviews import chunk_reviews
|
| 87 |
+
self._run_step("Chunk reviews", chunk_reviews,
|
| 88 |
+
str(self.data_dir / "review_highlights.txt"),
|
| 89 |
+
str(self.data_dir / "review_chunks.jsonl"),
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
def run_rec_preparation(self, stage: str = "all") -> None:
|
| 93 |
+
"""Stage 2: RecSys data preparation."""
|
| 94 |
+
if stage not in ("all", "rec"):
|
| 95 |
+
return
|
| 96 |
+
|
| 97 |
+
from scripts.data.split_rec_data import run as split_run
|
| 98 |
+
self._run_step("Split train/val/test", split_run,
|
| 99 |
+
data_path=self.data_dir / "raw" / "Books_rating.csv",
|
| 100 |
+
output_dir=self.rec_dir,
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
from scripts.data.build_sequences import build_sequences
|
| 104 |
+
self._run_step("Build user sequences", build_sequences, str(self.rec_dir))
|
| 105 |
+
|
| 106 |
+
def run_index_building(self, stage: str = "all") -> None:
|
| 107 |
+
"""Stage 3: Index building."""
|
| 108 |
+
if stage not in ("all", "index") or self.skip_index:
|
| 109 |
+
return
|
| 110 |
+
|
| 111 |
+
from scripts.init_sqlite_db import init_sqlite_db
|
| 112 |
+
self._run_step("Build SQLite metadata (books.db)", init_sqlite_db, str(self.data_dir))
|
| 113 |
+
|
| 114 |
+
from scripts.data.init_dual_index import init_chunk_index
|
| 115 |
+
self._run_step("Build chunk vector index", init_chunk_index)
|
| 116 |
+
|
| 117 |
+
def run_training(self, stage: str = "all") -> None:
|
| 118 |
+
"""Stage 4: Model training via src imports."""
|
| 119 |
+
if stage not in ("all", "models") or self.skip_models:
|
| 120 |
+
return
|
| 121 |
+
|
| 122 |
+
train_path = self.rec_dir / "train.csv"
|
| 123 |
+
if not train_path.exists():
|
| 124 |
+
logger.warning("train.csv not found, skipping model training")
|
| 125 |
+
return
|
| 126 |
+
|
| 127 |
+
import pandas as pd
|
| 128 |
+
df = pd.read_csv(train_path)
|
| 129 |
+
|
| 130 |
+
from src.recall.itemcf import ItemCF
|
| 131 |
+
self._run_step("Train ItemCF", lambda: ItemCF().fit(df))
|
| 132 |
+
|
| 133 |
+
from src.recall.usercf import UserCF
|
| 134 |
+
self._run_step("Train UserCF", lambda: UserCF().fit(df))
|
| 135 |
+
|
| 136 |
+
from src.recall.swing import Swing
|
| 137 |
+
self._run_step("Train Swing", lambda: Swing().fit(df))
|
| 138 |
+
|
| 139 |
+
from src.recall.popularity import PopularityRecall
|
| 140 |
+
self._run_step("Train Popularity", lambda: PopularityRecall().fit(df))
|
| 141 |
+
|
| 142 |
+
from src.recall.item2vec import Item2Vec
|
| 143 |
+
self._run_step("Train Item2Vec", lambda: Item2Vec().fit(df))
|
| 144 |
+
|
| 145 |
+
from src.recall.embedding import YoutubeDNNRecall
|
| 146 |
+
self._run_step("Train YoutubeDNN", lambda: YoutubeDNNRecall().fit(
|
| 147 |
+
df, books_path=self.data_dir / "books_processed.csv"
|
| 148 |
+
))
|
| 149 |
+
|
| 150 |
+
from src.recall.sasrec_recall import SASRecRecall
|
| 151 |
+
self._run_step("Train SASRec", lambda: SASRecRecall().fit(df))
|
| 152 |
+
|
| 153 |
+
from scripts.model.train_ranker import train_ranker, train_stacking
|
| 154 |
+
self._run_step("Train Ranker", train_stacking if self.stacking else train_ranker)
|
| 155 |
+
|
| 156 |
+
def run_evaluation(self) -> None:
|
| 157 |
+
"""Stage 5: Validation."""
|
| 158 |
+
def _validate():
|
| 159 |
+
from scripts.data.validate_data import (
|
| 160 |
+
validate_raw, validate_processed, validate_rec,
|
| 161 |
+
validate_index, validate_models,
|
| 162 |
+
)
|
| 163 |
+
validate_raw()
|
| 164 |
+
validate_processed()
|
| 165 |
+
validate_rec()
|
| 166 |
+
validate_index()
|
| 167 |
+
validate_models()
|
| 168 |
+
|
| 169 |
+
self._run_step("Validate pipeline", _validate)
|
| 170 |
+
|
| 171 |
+
def run(self, stage: str = "all") -> None:
|
| 172 |
+
"""Execute full pipeline: Data Cleaning -> Training -> Evaluation."""
|
| 173 |
+
logger.info("=" * 60)
|
| 174 |
+
logger.info("Pipeline: Data Cleaning -> Training -> Evaluation")
|
| 175 |
+
logger.info("=" * 60)
|
| 176 |
+
|
| 177 |
+
start_total = time.time()
|
| 178 |
+
|
| 179 |
+
self.run_data_cleaning(stage)
|
| 180 |
+
self.run_rec_preparation(stage)
|
| 181 |
+
self.run_index_building(stage)
|
| 182 |
+
self.run_training(stage)
|
| 183 |
+
self.run_evaluation()
|
| 184 |
+
|
| 185 |
+
elapsed = time.time() - start_total
|
| 186 |
+
logger.info("=" * 60)
|
| 187 |
+
logger.info("Pipeline completed in %.1f min", elapsed / 60)
|
| 188 |
+
logger.info("=" * 60)
|
| 189 |
|
| 190 |
|
| 191 |
def main():
|
| 192 |
+
parser = argparse.ArgumentParser(description="Run data pipeline (no subprocess)")
|
| 193 |
+
parser.add_argument(
|
| 194 |
+
"--stage",
|
| 195 |
+
choices=["all", "books", "rec", "index", "models"],
|
| 196 |
+
default="all",
|
| 197 |
+
help="Which stage to run",
|
| 198 |
+
)
|
| 199 |
parser.add_argument("--skip-models", action="store_true", help="Skip model training")
|
| 200 |
parser.add_argument("--skip-index", action="store_true", help="Skip index building")
|
| 201 |
parser.add_argument("--validate-only", action="store_true", help="Only run validation")
|
| 202 |
+
parser.add_argument("--device", default=None, help="Device for ML (cpu/cuda/mps)")
|
| 203 |
+
parser.add_argument("--stacking", action="store_true", help="Enable stacking ranker")
|
| 204 |
args = parser.parse_args()
|
| 205 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
if args.validate_only:
|
| 207 |
+
logger.info("Validation only")
|
| 208 |
+
Pipeline().run_evaluation()
|
| 209 |
return
|
| 210 |
+
|
| 211 |
+
pipeline = Pipeline(
|
| 212 |
+
device=args.device,
|
| 213 |
+
skip_models=args.skip_models,
|
| 214 |
+
skip_index=args.skip_index,
|
| 215 |
+
stacking=args.stacking,
|
|
|
|
|
|
|
|
| 216 |
)
|
| 217 |
+
pipeline.run(stage=args.stage)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
|
| 219 |
|
| 220 |
if __name__ == "__main__":
|
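For reference, the class-based runner can be driven from the CLI flags above or imported directly; a short usage sketch with the constructor arguments shown in the diff:

from scripts.run_pipeline import Pipeline

Pipeline().run(stage="all")                                  # full pipeline
Pipeline(skip_models=True).run(stage="books")                # book processing only
Pipeline(device="mps", stacking=True).run(stage="models")    # training with the stacking ranker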
src/core/model_loader.py
CHANGED
|
@@ -13,7 +13,7 @@ Usage:
|
|
| 13 |
import os
|
| 14 |
import logging
|
| 15 |
from pathlib import Path
|
| 16 |
-
from huggingface_hub import
|
| 17 |
from src.config import DATA_DIR
|
| 18 |
|
| 19 |
logger = logging.getLogger(__name__)
|
|
|
|
| 13 |
import os
|
| 14 |
import logging
|
| 15 |
from pathlib import Path
|
| 16 |
+
from huggingface_hub import snapshot_download
|
| 17 |
from src.config import DATA_DIR
|
| 18 |
|
| 19 |
logger = logging.getLogger(__name__)
|
src/data/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data access layer for book recommendation system."""
|
| 2 |
+
|
| 3 |
+
from src.data.repository import DataRepository, data_repository
|
| 4 |
+
|
| 5 |
+
__all__ = ["DataRepository", "data_repository"]
|
src/data/repository.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified Data Repository for book recommendation system.
|
| 3 |
+
|
| 4 |
+
Centralizes all core data access: books metadata, user history, etc.
|
| 5 |
+
Replaces scattered pandas.read_csv and pickle.load calls across services.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, List, Optional
|
| 10 |
+
import sqlite3
|
| 11 |
+
|
| 12 |
+
from src.config import DATA_DIR
|
| 13 |
+
from src.core.metadata_store import metadata_store
|
| 14 |
+
from src.utils import setup_logger
|
| 15 |
+
|
| 16 |
+
logger = setup_logger(__name__)
|
| 17 |
+
|
| 18 |
+
# Core data file paths
|
| 19 |
+
BOOKS_DB_PATH = DATA_DIR / "books.db"
|
| 20 |
+
BOOKS_PROCESSED_CSV = DATA_DIR / "books_processed.csv"
|
| 21 |
+
RECALL_MODELS_DB = DATA_DIR / "recall_models.db"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class DataRepository:
|
| 25 |
+
"""
|
| 26 |
+
Singleton data access layer. Manages loading of books_processed.csv,
|
| 27 |
+
books.db, recall_models.db (user_history), etc.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
_instance: Optional["DataRepository"] = None
|
| 31 |
+
|
| 32 |
+
def __new__(cls) -> "DataRepository":
|
| 33 |
+
if cls._instance is None:
|
| 34 |
+
cls._instance = super(DataRepository, cls).__new__(cls)
|
| 35 |
+
cls._instance._initialized = False
|
| 36 |
+
return cls._instance
|
| 37 |
+
|
| 38 |
+
def __init__(self) -> None:
|
| 39 |
+
if getattr(self, "_initialized", False):
|
| 40 |
+
return
|
| 41 |
+
self._initialized = True
|
| 42 |
+
self._recall_conn: Optional[sqlite3.Connection] = None
|
| 43 |
+
logger.info("DataRepository: Initialized (singleton)")
|
| 44 |
+
|
| 45 |
+
def _get_recall_connection(self) -> Optional[sqlite3.Connection]:
|
| 46 |
+
"""Lazy SQLite connection for recall_models.db."""
|
| 47 |
+
if self._recall_conn is None:
|
| 48 |
+
if not RECALL_MODELS_DB.exists():
|
| 49 |
+
logger.warning(f"recall_models.db not found at {RECALL_MODELS_DB}")
|
| 50 |
+
return None
|
| 51 |
+
try:
|
| 52 |
+
self._recall_conn = sqlite3.connect(
|
| 53 |
+
str(RECALL_MODELS_DB), check_same_thread=False
|
| 54 |
+
)
|
| 55 |
+
except sqlite3.Error as e:
|
| 56 |
+
logger.error(f"DataRepository: Failed to connect to recall DB: {e}")
|
| 57 |
+
return self._recall_conn
|
| 58 |
+
|
| 59 |
+
def get_book_metadata(self, isbn: str) -> Optional[Dict[str, Any]]:
|
| 60 |
+
"""
|
| 61 |
+
Get book metadata by ISBN.
|
| 62 |
+
|
| 63 |
+
Uses MetadataStore (books.db) as primary source. Returns None if not found.
|
| 64 |
+
"""
|
| 65 |
+
meta = metadata_store.get_book_metadata(str(isbn))
|
| 66 |
+
return meta if meta else None
|
| 67 |
+
|
| 68 |
+
def get_user_history(self, user_id: str) -> List[str]:
|
| 69 |
+
"""
|
| 70 |
+
Get user's interaction history (ISBNs) from recall_models.db.
|
| 71 |
+
|
| 72 |
+
Used by recommendation algorithms (ItemCF, etc.). Returns empty list if
|
| 73 |
+
DB unavailable or user has no history.
|
| 74 |
+
"""
|
| 75 |
+
conn = self._get_recall_connection()
|
| 76 |
+
if not conn:
|
| 77 |
+
return []
|
| 78 |
+
try:
|
| 79 |
+
cursor = conn.cursor()
|
| 80 |
+
cursor.execute(
|
| 81 |
+
"SELECT isbn FROM user_history WHERE user_id = ?", (user_id,)
|
| 82 |
+
)
|
| 83 |
+
return [row[0] for row in cursor.fetchall()]
|
| 84 |
+
except sqlite3.Error as e:
|
| 85 |
+
logger.error(f"DataRepository: get_user_history failed: {e}")
|
| 86 |
+
return []
|
| 87 |
+
|
| 88 |
+
def get_all_categories(self) -> List[str]:
|
| 89 |
+
"""Get unique book categories. Delegates to MetadataStore."""
|
| 90 |
+
return metadata_store.get_all_categories()
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# Global singleton instance
|
| 94 |
+
data_repository = DataRepository()
|
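A minimal usage sketch of the new repository from a service layer (the user id and ISBN below are placeholders):

from src.data import data_repository

history = data_repository.get_user_history("user_123")        # list of ISBNs, [] if unknown
book = data_repository.get_book_metadata("9780000000000")     # dict or None
categories = data_repository.get_all_categories()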
src/init_db.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
import os
|
| 2 |
import shutil
|
| 3 |
import sys
|
| 4 |
-
import torch
|
| 5 |
from pathlib import Path
|
| 6 |
|
| 7 |
# Add project root to Python path
|
|
@@ -21,20 +20,10 @@ def init_db():
|
|
| 21 |
# FIX: Disable Tokenizers Parallelism to prevent deadlocks on macOS
|
| 22 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 23 |
|
| 24 |
-
# Force CPU for data ingestion to avoid MPS (Metal) async hangs during long processing
|
| 25 |
-
#
|
| 26 |
device = "cpu"
|
| 27 |
print("🐢 Forcing CPU for stable database ingestion (prevents macOS Freezes).")
|
| 28 |
-
|
| 29 |
-
# if torch.backends.mps.is_available():
|
| 30 |
-
# device = "mps"
|
| 31 |
-
# print("⚡️ MacOS GPU (MPS) Detected! switching to GPU acceleration.")
|
| 32 |
-
# elif torch.cuda.is_available():
|
| 33 |
-
# device = "cuda"
|
| 34 |
-
# print("⚡️ NVIDIA GPU (CUDA) Detected!")
|
| 35 |
-
# else:
|
| 36 |
-
# device = "cpu"
|
| 37 |
-
# print("🐢 No GPU detected, running on CPU (this might be slow).")
|
| 38 |
|
| 39 |
# 1. Clear existing DB if any (to avoid duplicates/corruption)
|
| 40 |
if CHROMA_DB_DIR.exists():
|
|
|
|
| 1 |
import os
|
| 2 |
import shutil
|
| 3 |
import sys
|
|
|
|
| 4 |
from pathlib import Path
|
| 5 |
|
| 6 |
# Add project root to Python path
|
|
|
|
| 20 |
# FIX: Disable Tokenizers Parallelism to prevent deadlocks on macOS
|
| 21 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 22 |
|
| 23 |
+
# Force CPU for data ingestion to avoid MPS (Metal) async hangs during long processing.
|
| 24 |
+
# Reliability is key for building the DB; GPU acceleration is only needed for inference.
|
| 25 |
device = "cpu"
|
| 26 |
print("🐢 Forcing CPU for stable database ingestion (prevents macOS Freezes).")
|
|
|
|
|
|
|
| 27 |
|
| 28 |
# 1. Clear existing DB if any (to avoid duplicates/corruption)
|
| 29 |
if CHROMA_DB_DIR.exists():
|
src/main.py
CHANGED
|
@@ -15,8 +15,6 @@ from src.user.profile_store import (
|
|
| 15 |
update_book_rating, update_reading_status, update_book_comment,
|
| 16 |
get_favorites_with_metadata, get_reading_stats
|
| 17 |
)
|
| 18 |
-
from src.marketing.persona import build_persona
|
| 19 |
-
from src.marketing.highlights import generate_highlights
|
| 20 |
from src.api.chat import router as chat_router # ✨ NEW
|
| 21 |
from src.services.chat_service import chat_service # ✨ NEW
|
| 22 |
from src.services.recommend_service import RecommendationService # ✨ NEW
|
|
@@ -236,9 +234,6 @@ async def favorites_list(user_id: str):
|
|
| 236 |
favorites_meta = get_favorites_with_metadata(user_id)
|
| 237 |
# ENGINEERING IMPROVEMENT: Zero-RAM Lookup
|
| 238 |
from src.core.metadata_store import metadata_store
|
| 239 |
-
|
| 240 |
-
results = []
|
| 241 |
-
# Lazy load fetcher (Handled inside utils now)
|
| 242 |
from src.utils import enrich_book_metadata
|
| 243 |
|
| 244 |
results = []
|
|
|
|
| 15 |
update_book_rating, update_reading_status, update_book_comment,
|
| 16 |
get_favorites_with_metadata, get_reading_stats
|
| 17 |
)
|
|
|
|
|
|
|
| 18 |
from src.api.chat import router as chat_router # ✨ NEW
|
| 19 |
from src.services.chat_service import chat_service # ✨ NEW
|
| 20 |
from src.services.recommend_service import RecommendationService # ✨ NEW
|
|
|
|
| 234 |
favorites_meta = get_favorites_with_metadata(user_id)
|
| 235 |
# ENGINEERING IMPROVEMENT: Zero-RAM Lookup
|
| 236 |
from src.core.metadata_store import metadata_store
|
|
|
|
|
|
|
|
|
|
| 237 |
from src.utils import enrich_book_metadata
|
| 238 |
|
| 239 |
results = []
|
src/marketing/persona.py
CHANGED
|
@@ -1,15 +1,17 @@
|
|
| 1 |
from collections import Counter
|
| 2 |
-
from typing import Dict, List, Any
|
| 3 |
-
import pandas as pd
|
| 4 |
|
| 5 |
from src.utils import setup_logger
|
| 6 |
|
| 7 |
logger = setup_logger(__name__)
|
| 8 |
|
| 9 |
|
| 10 |
-
def build_persona(fav_isbns: List[str], books:
|
| 11 |
-
"""
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
| 13 |
return {
|
| 14 |
"summary": "No profile yet. Start by adding your favorite books to see personalized recommendations.",
|
| 15 |
"top_authors": [],
|
|
|
|
| 1 |
from collections import Counter
|
| 2 |
+
from typing import Dict, List, Any, Optional
|
|
|
|
| 3 |
|
| 4 |
from src.utils import setup_logger
|
| 5 |
|
| 6 |
logger = setup_logger(__name__)
|
| 7 |
|
| 8 |
|
| 9 |
+
def build_persona(fav_isbns: List[str], books: Optional[Any] = None) -> Dict[str, Any]:
|
| 10 |
+
"""
|
| 11 |
+
Aggregate a simple persona from favorites: top authors and categories.
|
| 12 |
+
Uses MetadataStore for lookups; the books param is legacy and unused.
|
| 13 |
+
"""
|
| 14 |
+
if not fav_isbns:
|
| 15 |
return {
|
| 16 |
"summary": "No profile yet. Start by adding your favorite books to see personalized recommendations.",
|
| 17 |
"top_authors": [],
|
src/marketing/personalized_highlight.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
import json
|
| 2 |
-
import pandas as pd
|
| 3 |
from src.marketing.persona import build_persona
|
| 4 |
from src.marketing.highlights import generate_highlights
|
| 5 |
|
|
|
|
| 1 |
import json
|
|
|
|
| 2 |
from src.marketing.persona import build_persona
|
| 3 |
from src.marketing.highlights import generate_highlights
|
| 4 |
|
src/marketing/verify_p3.py
CHANGED
@@ -7,7 +7,7 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from peft import PeftModel
 from modelscope import snapshot_download
-from guardrails import ContentGuardrail
+from src.marketing.guardrails import ContentGuardrail

 # Config
 BASE_MODEL_ID = "qwen/Qwen2-7B-Instruct"
src/recall/embedding.py
CHANGED
@@ -5,16 +5,63 @@ V2.7: Replaced torch.matmul brute-force search with Faiss IndexFlatIP
 for SIMD-accelerated inner-product retrieval.
 """

-import torch
-import numpy as np
 import pickle
 import logging
-import faiss
 from pathlib import Path
+from typing import Optional
+
+import faiss
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import Dataset, DataLoader
+from tqdm import tqdm
+
+from src.recall.sequence_utils import build_sequences_from_df
 from src.recall.youtube_dnn import YoutubeDNN

 logger = logging.getLogger(__name__)

+
+class _RetrievalDataset(Dataset):
+    """Internal dataset for YoutubeDNN training (history -> target)."""
+
+    def __init__(self, user_seqs: dict, item_to_cate: dict, default_cate: int, max_history: int):
+        self.samples: list[tuple[list[int], int, int]] = []
+        self.item_to_cate = item_to_cate
+        self.default_cate = default_cate
+        self.max_history = max_history
+
+        for user, seq in user_seqs.items():
+            if len(seq) < 3:
+                continue
+            train_seq = seq[:-2]
+            for i in range(1, len(train_seq)):
+                target = train_seq[i]
+                history = train_seq[:i]
+                if len(history) > max_history:
+                    history = history[-max_history:]
+                target_cate = item_to_cate.get(target, default_cate)
+                self.samples.append((history, target, target_cate))
+
+    def __len__(self) -> int:
+        return len(self.samples)
+
+    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        history, target, target_cate = self.samples[idx]
+        padded = np.zeros(self.max_history, dtype=np.int64)
+        length = min(len(history), self.max_history)
+        if length > 0:
+            padded[:length] = history[-length:]
+        return (
+            torch.LongTensor(padded),
+            torch.tensor(target, dtype=torch.long),
+            torch.tensor(target_cate, dtype=torch.long),
+        )
+
+
 class YoutubeDNNRecall:
     def __init__(self, data_dir='data/rec', model_dir='data/model/recall'):
         self.data_dir = Path(data_dir)
@@ -33,7 +80,117 @@ class YoutubeDNNRecall:
         self.id_to_item = {}
         self.meta = None

-    def load(self):
+    def fit(
+        self,
+        df: pd.DataFrame,
+        books_path: Optional[Path] = None,
+        epochs: int = 10,
+        batch_size: int = 512,
+        lr: float = 0.001,
+        embed_dim: int = 64,
+        max_history: int = 20,
+    ) -> "YoutubeDNNRecall":
+        """
+        Train YoutubeDNN from interaction DataFrame. Builds sequences internally.
+
+        Args:
+            df: [user_id, isbn, timestamp] (timestamp optional)
+            books_path: Path to books_processed.csv for categories. If None, uses default.
+            epochs, batch_size, lr: Training hyperparameters.
+        """
+        logger.info("Building sequences from DataFrame...")
+        user_seqs, item_map = build_sequences_from_df(df, max_len=50)
+
+        self.item_map = item_map
+        self.id_to_item = {v: k for k, v in item_map.items()}
+        vocab_size = len(item_map) + 1
+
+        # Category map
+        cate_map: dict[str, int] = {"<PAD>": 0, "<UNK>": 1}
+        item_to_cate: dict[int, int] = {}
+        default_cate = 1
+
+        books_path = Path(books_path) if books_path else self.data_dir.parent / "books_processed.csv"
+        if books_path.exists():
+            books_df = pd.read_csv(books_path, usecols=["isbn13", "simple_categories"])
+            books_df["isbn"] = books_df["isbn13"].astype(str)
+            for _, row in books_df.iterrows():
+                isbn = str(row["isbn"])
+                if isbn in item_map:
+                    iid = item_map[isbn]
+                    cates = str(row["simple_categories"]).split(";")
+                    main_cate = cates[0].strip() if cates else "Unknown"
+                    if main_cate not in cate_map:
+                        cate_map[main_cate] = len(cate_map)
+                    item_to_cate[iid] = cate_map[main_cate]
+            default_cate = cate_map.get("Unknown", 1)
+
+        for iid in range(1, vocab_size):
+            if iid not in item_to_cate:
+                item_to_cate[iid] = default_cate
+
+        cate_vocab_size = len(cate_map)
+
+        user_config = {"vocab_size": vocab_size, "embed_dim": embed_dim, "history_len": max_history}
+        item_config = {
+            "vocab_size": vocab_size,
+            "embed_dim": embed_dim,
+            "cate_vocab_size": cate_vocab_size,
+            "cate_embed_dim": 32,
+        }
+        model_config = {"hidden_dims": [128, 64], "dropout": 0.1}
+
+        dataset = _RetrievalDataset(user_seqs, item_to_cate, default_cate, max_history)
+        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
+
+        self.model = YoutubeDNN(user_config, item_config, model_config).to(self.device)
+        optimizer = optim.Adam(self.model.parameters(), lr=lr)
+        criterion = nn.CrossEntropyLoss()
+
+        logger.info("Training YoutubeDNN...")
+        self.model.train()
+        for epoch in range(epochs):
+            total_loss = 0.0
+            steps = 0
+            for history, target_item, target_cate in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
+                history = history.to(self.device)
+                target_item = target_item.to(self.device)
+                target_cate = target_cate.to(self.device)
+                optimizer.zero_grad()
+                user_vec = self.model.user_tower(history)
+                item_vec = self.model.item_tower(target_item, target_cate)
+                user_vec = nn.functional.normalize(user_vec, p=2, dim=1)
+                item_vec = nn.functional.normalize(item_vec, p=2, dim=1)
+                logits = torch.matmul(user_vec, item_vec.t()) / 0.1
+                labels = torch.arange(len(user_vec)).to(self.device)
+                loss = criterion(logits, labels)
+                loss.backward()
+                optimizer.step()
+                total_loss += loss.item()
+                steps += 1
+            logger.info(f"Epoch {epoch+1} Loss: {total_loss/steps:.4f}")
+
+        self.save_dir.mkdir(parents=True, exist_ok=True)
+        torch.save(self.model.state_dict(), self.model_dir / "youtube_dnn.pt")
+        self.meta = {
+            "user_config": user_config,
+            "item_config": item_config,
+            "model_config": model_config,
+            "item_to_cate": item_to_cate,
+        }
+        with open(self.model_dir / "youtube_dnn_meta.pkl", "wb") as f:
+            pickle.dump(self.meta, f)
+
+        self.data_dir.mkdir(parents=True, exist_ok=True)
+        with open(self.data_dir / "item_map.pkl", "wb") as f:
+            pickle.dump(self.item_map, f)
+        with open(self.data_dir / "user_sequences.pkl", "wb") as f:
+            pickle.dump(user_seqs, f)
+
+        logger.info(f"YoutubeDNN saved to {self.model_dir}")
+        return self
+
+    def load(self) -> bool:
         try:
             logger.info("Loading YoutubeDNN model...")
             # Load metadata
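
Note: a minimal training sketch for the new `fit` entry point, using the `[user_id, isbn, timestamp]` columns named in its docstring. The toy data, epoch count, and batch size are illustrative only, and the constructor defaults are assumed to set up `self.device` as before.

# Sketch only: toy interactions; real training uses the cleaned rec dataset.
import pandas as pd
from src.recall.embedding import YoutubeDNNRecall

interactions = pd.DataFrame({
    "user_id": ["u1"] * 5 + ["u2"] * 4,
    "isbn": ["b1", "b2", "b3", "b4", "b5", "b2", "b3", "b1", "b4"],
    "timestamp": range(9),
})

recall = YoutubeDNNRecall(data_dir="data/rec", model_dir="data/model/recall")
recall.fit(interactions, epochs=1, batch_size=4)  # writes youtube_dnn.pt plus the meta/map pickles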
src/recall/fusion.py
CHANGED
@@ -1,5 +1,7 @@
 import logging
 from collections import defaultdict
+from typing import Optional
+
 from src.recall.itemcf import ItemCF
 from src.recall.usercf import UserCF
 from src.recall.popularity import PopularityRecall
@@ -10,8 +12,43 @@ from src.recall.sasrec_recall import SASRecRecall

 logger = logging.getLogger(__name__)

+# Default: only the 3 most effective channels enabled. Others available but off.
+DEFAULT_CHANNEL_CONFIG = {
+    "itemcf": {"enabled": True, "weight": 1.0},
+    "sasrec": {"enabled": True, "weight": 1.0},
+    "youtube_dnn": {"enabled": True, "weight": 1.0},
+    "usercf": {"enabled": False, "weight": 1.0},
+    "swing": {"enabled": False, "weight": 1.0},
+    "item2vec": {"enabled": False, "weight": 0.8},
+    "popularity": {"enabled": False, "weight": 0.5},
+}
+
+
+def _merge_config(default: dict, override: Optional[dict]) -> dict:
+    """Deep-merge override into default (shallow per channel)."""
+    merged = {k: dict(v) for k, v in default.items()}
+    if override:
+        for ch, cfg in override.items():
+            if ch in merged:
+                merged[ch].update(cfg)
+            else:
+                merged[ch] = dict(cfg)
+    return merged
+
+
 class RecallFusion:
-    def __init__(
+    def __init__(
+        self,
+        data_dir: str = "data/rec",
+        model_dir: str = "data/model/recall",
+        channel_config: Optional[dict] = None,
+        rrf_k: int = 60,
+    ):
+        self.data_dir = data_dir
+        self.model_dir = model_dir
+        self.channel_config = _merge_config(DEFAULT_CHANNEL_CONFIG, channel_config)
+        self.rrf_k = rrf_k
+
         self.itemcf = ItemCF(data_dir, model_dir)
         self.usercf = UserCF(data_dir, model_dir)
         self.popularity = PopularityRecall(data_dir, model_dir)
@@ -21,8 +58,8 @@ class RecallFusion:
         self.sasrec = SASRecRecall(data_dir, model_dir)

         self.models_loaded = False
-
-    def load_models(self):
+
+    def load_models(self) -> None:
         if self.models_loaded:
             return

@@ -35,58 +72,57 @@ class RecallFusion:
         self.item2vec.load()
         self.sasrec.load()
         self.models_loaded = True
-
-    def get_recall_items(self, user_id, history_items=None, k=100):
+
+    def get_recall_items(self, user_id: str, history_items=None, k: int = 100):
         """
-        Multi-channel recall fusion using RRF
+        Multi-channel recall fusion using RRF. Channels and weights controlled by config.
         """
         if not self.models_loaded:
             self.load_models()

         candidates = defaultdict(float)
+        cfg = self.channel_config
+
+        if cfg.get("youtube_dnn", {}).get("enabled", False):
+            recs = self.youtube_dnn.recommend(user_id, history_items, top_k=k)
+            self._add_to_candidates(candidates, recs, cfg["youtube_dnn"]["weight"])
+
+        if cfg.get("itemcf", {}).get("enabled", False):
+            recs = self.itemcf.recommend(user_id, history_items, top_k=k)
+            self._add_to_candidates(candidates, recs, cfg["itemcf"]["weight"])
+
+        if cfg.get("usercf", {}).get("enabled", False):
+            recs = self.usercf.recommend(user_id, history_items, top_k=k)
+            self._add_to_candidates(candidates, recs, cfg["usercf"]["weight"])
+
+        if cfg.get("swing", {}).get("enabled", False):
+            recs = self.swing.recommend(user_id, history_items, top_k=k)
+            self._add_to_candidates(candidates, recs, cfg["swing"]["weight"])
+
+        if cfg.get("sasrec", {}).get("enabled", False):
+            recs = self.sasrec.recommend(user_id, history_items, top_k=k)
+            self._add_to_candidates(candidates, recs, cfg["sasrec"]["weight"])
+
+        if cfg.get("item2vec", {}).get("enabled", False):
+            recs = self.item2vec.recommend(user_id, history_items, top_k=k)
+            self._add_to_candidates(candidates, recs, cfg["item2vec"]["weight"])
+
+        if cfg.get("popularity", {}).get("enabled", False):
+            recs = self.popularity.recommend(user_id, top_k=k)
+            self._add_to_candidates(candidates, recs, cfg["popularity"]["weight"])
+
         sorted_cands = sorted(candidates.items(), key=lambda x: x[1], reverse=True)
         return sorted_cands[:k]
-
-    def _add_to_candidates(self, candidates, recs, weight
+
+    def _add_to_candidates(self, candidates, recs, weight: float) -> None:
         """
-        Add recommendations to candidate pool using RRF
-        score += weight * (1 / (
+        Add recommendations to candidate pool using RRF.
+        score += weight * (1 / (rrf_k + rank + 1))
         """
         if not recs:
             return
         for rank, (item, score) in enumerate(recs):
-            rrf_score = weight * (1.0 / (rrf_k + rank + 1))
+            rrf_score = weight * (1.0 / (self.rrf_k + rank + 1))
             candidates[item] += rrf_score

 if __name__ == "__main__":
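
Note: to make the RRF arithmetic in `_add_to_candidates` concrete, here is a small self-contained sketch of the fusion rule with `rrf_k=60` (the default above). It uses toy channel outputs and does not import the project classes.

from collections import defaultdict

# Reciprocal Rank Fusion: score += weight * 1 / (rrf_k + rank + 1); raw channel scores are ignored.
def rrf_fuse(channel_results, rrf_k=60):
    candidates = defaultdict(float)
    for weight, recs in channel_results:
        for rank, (item, _score) in enumerate(recs):
            candidates[item] += weight * (1.0 / (rrf_k + rank + 1))
    return sorted(candidates.items(), key=lambda x: x[1], reverse=True)

# Two toy channels with equal weight.
itemcf = [("bookA", 0.9), ("bookB", 0.7), ("bookC", 0.5)]
sasrec = [("bookB", 12.1), ("bookD", 9.3)]

print(rrf_fuse([(1.0, itemcf), (1.0, sasrec)]))
# bookB appears in both channels (ranks 1 and 0), so it wins:
# 1/62 + 1/61 ≈ 0.0325 versus bookA's 1/61 ≈ 0.0164.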
src/recall/itemcf.py
CHANGED
@@ -1,37 +1,150 @@
-import pickle
 import math
-import
-import pandas as pd
-from tqdm import tqdm
+import sqlite3
 from collections import defaultdict
 from pathlib import Path
+from typing import Optional
+
+import pandas as pd
+from tqdm import tqdm
+
 import logging

 logger = logging.getLogger(__name__)

+# Direction weights for asymmetric co-occurrence (CHANGELOG: forward=1.0, backward=0.7)
+FORWARD_WEIGHT = 1.0
+BACKWARD_WEIGHT = 0.7
+
+
 class ItemCF:
     """
     Item-based Collaborative Filtering.
-
-    This change ensures zero-RAM loading for the similarity matrix while maintaining
-    100% mathematical parity with the original Python implementation.
+
+    Co-occurrence similarity with direction weight: when user reads item A then B,
+    sim(A,B) += 1.0 (forward), sim(B,A) += 0.7 (backward). This captures temporal
+    "read-after" patterns.
+
+    Persists to SQLite (recall_models.db) for zero-RAM inference.
     """
-
+
+    def __init__(self, data_dir: str = "data/rec", save_dir: str = "data/model/recall"):
         self.data_dir = Path(data_dir)
         self.save_dir = Path(save_dir)
-        self.
-        self.
+        self.save_dir.mkdir(parents=True, exist_ok=True)
+        self.db_path = self.data_dir.parent / "recall_models.db"
+        self.conn: Optional[sqlite3.Connection] = None
+
+    def fit(self, df: pd.DataFrame, top_k_sim: int = 200) -> "ItemCF":
+        """
+        Build co-occurrence similarity matrix with direction weight, then persist to SQLite.
+
+        Args:
+            df: DataFrame with columns [user_id, isbn, rating, timestamp].
+                If timestamp is missing, assumes row order per user.
+            top_k_sim: Keep only top-k similar items per item to reduce size.
+        """
+        logger.info("Building ItemCF similarity matrix (direction-weighted co-occurrence)...")
+
+        # 1. Build per-user chronologically ordered item sequences
+        user_seqs: dict[str, list[tuple[str, float]]] = defaultdict(list)
+        if "timestamp" in df.columns:
+            for _, row in tqdm(df.iterrows(), total=len(df), desc="Building user sequences"):
+                user_seqs[row["user_id"]].append((str(row["isbn"]), float(row["timestamp"])))
+            for uid in user_seqs:
+                user_seqs[uid] = [x[0] for x in sorted(user_seqs[uid], key=lambda t: t[1])]
+        else:
+            for _, row in tqdm(df.iterrows(), total=len(df), desc="Building user sequences"):
+                user_seqs[row["user_id"]].append(str(row["isbn"]))
+
+        user_hist = {u: list(items) for u, items in user_seqs.items()}
+
+        # 2. Count users per item (for cosine normalization)
+        item_users: dict[str, set[str]] = defaultdict(set)
+        for user_id, items in user_hist.items():
+            for item in items:
+                item_users[item].add(user_id)
+        item_counts = {k: len(v) for k, v in item_users.items()}
+
+        # 3. Build item-item co-occurrence with direction weight
+        sim: dict[str, dict[str, float]] = defaultdict(lambda: defaultdict(float))
+        for user_id, items in tqdm(user_hist.items(), desc="Computing co-occurrence"):
+            for i in range(len(items)):
+                item_i = items[i]
+                for j in range(i + 1, len(items)):
+                    item_j = items[j]
+                    # Forward: i before j -> sim(i,j) += 1.0
+                    sim[item_i][item_j] += FORWARD_WEIGHT
+                    # Backward: j after i -> sim(j,i) += 0.7
+                    sim[item_j][item_i] += BACKWARD_WEIGHT
+
+        # 4. Normalize by sqrt(|N_i| * |N_j|) (cosine-style)
+        logger.info("Normalizing ItemCF matrix...")
+        final_sim: dict[str, dict[str, float]] = {}
+        for item_i, related in tqdm(sim.items(), desc="Normalizing"):
+            ni = item_counts.get(item_i, 1)
+            pruned = sorted(related.items(), key=lambda x: x[1], reverse=True)[:top_k_sim]
+            final_sim[item_i] = {}
+            for item_j, raw_score in pruned:
+                nj = item_counts.get(item_j, 1)
+                norm = math.sqrt(ni * nj)
+                if norm > 0:
+                    final_sim[item_i][item_j] = raw_score / norm
+
+        self._sim_matrix = final_sim
+        self._user_hist = user_hist
+        self.save()
+        logger.info(f"ItemCF built: {len(final_sim)} items, saved to {self.db_path}")
+        return self
+
+    def save(self) -> None:
+        """Persist similarity matrix and user history to SQLite."""
+        if not hasattr(self, "_sim_matrix") or not hasattr(self, "_user_hist"):
+            logger.warning("ItemCF.save: No fitted model to save.")
+            return
+
+        self.db_path.parent.mkdir(parents=True, exist_ok=True)
+        conn = sqlite3.connect(str(self.db_path))
+        cursor = conn.cursor()
+
+        cursor.execute("DROP TABLE IF EXISTS item_similarity")
+        cursor.execute("""
+            CREATE TABLE item_similarity (item1 TEXT, item2 TEXT, score REAL)
+        """)
+        cursor.execute("DROP TABLE IF EXISTS user_history")
+        cursor.execute("""
+            CREATE TABLE user_history (user_id TEXT, isbn TEXT)
+        """)
+
+        batch = []
+        for item1, related in tqdm(self._sim_matrix.items(), desc="Writing item_similarity"):
+            for item2, score in related.items():
+                batch.append((item1, item2, score))
+                if len(batch) >= 100000:
+                    cursor.executemany("INSERT INTO item_similarity VALUES (?, ?, ?)", batch)
+                    batch = []
+        if batch:
+            cursor.executemany("INSERT INTO item_similarity VALUES (?, ?, ?)", batch)
+
+        batch = []
+        for user_id, isbns in tqdm(self._user_hist.items(), desc="Writing user_history"):
+            for isbn in isbns:
+                batch.append((user_id, isbn))
+                if len(batch) >= 100000:
+                    cursor.executemany("INSERT INTO user_history VALUES (?, ?)", batch)
+                    batch = []
+        if batch:
+            cursor.executemany("INSERT INTO user_history VALUES (?, ?)", batch)
+
+        cursor.execute("CREATE INDEX IF NOT EXISTS idx_item1 ON item_similarity(item1)")
+        cursor.execute("CREATE INDEX IF NOT EXISTS idx_user ON user_history(user_id)")
+        conn.commit()
+        conn.close()
+        logger.info(f"ItemCF saved to {self.db_path}")
+
+    def load(self) -> bool:
         if self.db_path.exists():
-            import sqlite3
             try:
-                self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
+                self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
                 logger.info(f"ItemCF: Connected to SQLite {self.db_path}")
                 return True
             except Exception as e:
@@ -86,8 +199,6 @@ class ItemCF:
             logger.error(f"ItemCF Query Error: {e}")
             return []

-    def save(self): pass # Migration is done via script
-    def fit(self, df): pass # Training should be done separately

 if __name__ == "__main__":
     # Test run
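
Note: a self-contained toy example of the direction-weighted co-occurrence and cosine-style normalization described in the class docstring (two users, three books). It mirrors the fit logic but does not import the ItemCF class.

import math
from collections import defaultdict

FORWARD_WEIGHT, BACKWARD_WEIGHT = 1.0, 0.7

# Two reading histories, already in chronological order.
user_hist = {"u1": ["A", "B", "C"], "u2": ["A", "B"]}

# Number of distinct users per item (for the sqrt normalization).
item_counts = defaultdict(int)
for items in user_hist.values():
    for item in set(items):
        item_counts[item] += 1

sim = defaultdict(lambda: defaultdict(float))
for items in user_hist.values():
    for i in range(len(items)):
        for j in range(i + 1, len(items)):
            sim[items[i]][items[j]] += FORWARD_WEIGHT   # read i, then j
            sim[items[j]][items[i]] += BACKWARD_WEIGHT  # reverse direction, discounted

# Cosine-style normalization: raw / sqrt(|N_i| * |N_j|)
norm_ab = sim["A"]["B"] / math.sqrt(item_counts["A"] * item_counts["B"])  # 2.0 / 2 = 1.0
norm_ba = sim["B"]["A"] / math.sqrt(item_counts["B"] * item_counts["A"])  # 1.4 / 2 = 0.7
print(norm_ab, norm_ba)  # the asymmetry encodes the "read A before B" pattern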
src/recall/popularity.py
CHANGED
@@ -1,5 +1,4 @@
 import pandas as pd
-from collections import defaultdict
 import pickle
 from pathlib import Path
 import logging
src/recall/sasrec_recall.py
CHANGED
@@ -10,13 +10,52 @@ for SIMD-accelerated approximate nearest neighbor search.

 import pickle
 import logging
-import numpy as np
-import faiss
 from pathlib import Path
+from typing import Optional
+
+import faiss
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import Dataset, DataLoader
+from tqdm import tqdm
+
+from src.model.sasrec import SASRec
+from src.recall.sequence_utils import build_sequences_from_df

 logger = logging.getLogger(__name__)


+class _SeqDataset(Dataset):
+    """Internal dataset for SASRec training (seq, pos, neg)."""
+
+    def __init__(self, seqs_dict: dict, num_items: int, max_len: int):
+        self.seqs: list[list[int]] = []
+        self.num_items = num_items
+        self.max_len = max_len
+
+        for seq in seqs_dict.values():
+            if len(seq) < 2:
+                continue
+            padded = [0] * max_len
+            seq_len = min(len(seq), max_len)
+            padded[-seq_len:] = seq[-seq_len:]
+            self.seqs.append(padded)
+        self.seqs = torch.LongTensor(self.seqs)
+
+    def __len__(self) -> int:
+        return len(self.seqs)
+
+    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        seq = self.seqs[idx]
+        pos = np.zeros_like(seq.numpy())
+        pos[:-1] = seq.numpy()[1:]
+        neg = np.random.randint(1, self.num_items + 1, size=len(seq))
+        return seq, torch.LongTensor(pos), torch.LongTensor(neg)
+
+
 class SASRecRecall:
     def __init__(self, data_dir='data/rec', model_dir='data/model/recall'):
         self.data_dir = Path(data_dir)
@@ -30,7 +69,123 @@ class SASRecRecall:
         self.faiss_index = None  # Faiss IndexFlatIP for fast inner-product search
         self.loaded = False

-    def load(self):
+    def fit(
+        self,
+        df: pd.DataFrame,
+        max_len: int = 50,
+        hidden_dim: int = 64,
+        epochs: int = 30,
+        batch_size: int = 128,
+        lr: float = 1e-4,
+    ) -> "SASRecRecall":
+        """
+        Train SASRec from interaction DataFrame. Builds sequences internally.
+
+        Args:
+            df: [user_id, isbn, timestamp] (timestamp optional)
+            max_len, hidden_dim, epochs, batch_size, lr: Training hyperparameters.
+        """
+        logger.info("Building sequences from DataFrame...")
+        user_seqs, item_map = build_sequences_from_df(df, max_len=max_len)
+
+        self.item_map = item_map
+        self.id_to_item = {v: k for k, v in item_map.items()}
+        num_items = len(item_map)
+
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+        elif torch.backends.mps.is_available():
+            device = torch.device("mps")
+        else:
+            device = torch.device("cpu")
+
+        logger.info(f"Training SASRec on {device}...")
+        dataset = _SeqDataset(user_seqs, num_items, max_len)
+        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+        model = SASRec(num_items, max_len, hidden_dim).to(device)
+        optimizer = optim.Adam(model.parameters(), lr=lr)
+        criterion = nn.BCEWithLogitsLoss()
+
+        model.train()
+        for epoch in range(epochs):
+            total_loss = 0.0
+            pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
+            for seq, pos, neg in pbar:
+                seq, pos, neg = seq.to(device), pos.to(device), neg.to(device)
+                seq_emb = model(seq)
+                mask = pos != 0
+                pos_emb = model.item_emb(pos)
+                neg_emb = model.item_emb(neg)
+                pos_logits = (seq_emb * pos_emb).sum(dim=-1)[mask]
+                neg_logits = (seq_emb * neg_emb).sum(dim=-1)[mask]
+                pos_labels = torch.ones_like(pos_logits)
+                neg_labels = torch.zeros_like(neg_logits)
+                loss = criterion(pos_logits, pos_labels) + criterion(neg_logits, neg_labels)
+                optimizer.zero_grad()
+                loss.backward()
+                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+                optimizer.step()
+                total_loss += loss.item()
+                pbar.set_postfix(loss=total_loss / (pbar.n + 1))
+
+        # Save model
+        sasrec_dir = self.model_dir.parent / "rec"
+        sasrec_dir.mkdir(parents=True, exist_ok=True)
+        torch.save(model.state_dict(), sasrec_dir / "sasrec_model.pth")
+
+        # Extract user embeddings
+        logger.info("Extracting user sequence embeddings...")
+        model.eval()
+        user_emb_dict: dict = {}
+        all_users = list(user_seqs.keys())
+
+        with torch.no_grad():
+            for i in tqdm(range(0, len(all_users), batch_size), desc="Embedding users"):
+                batch_users = all_users[i : i + batch_size]
+                batch_seqs = []
+                for u in batch_users:
+                    s = user_seqs[u]
+                    padded = [0] * max_len
+                    seq_len = min(len(s), max_len)
+                    if seq_len > 0:
+                        padded[-seq_len:] = s[-seq_len:]
+                    batch_seqs.append(padded)
+                input_tensor = torch.LongTensor(batch_seqs).to(device)
+                output = model(input_tensor)
+                last_state = output[:, -1, :].cpu().numpy()
+                for j, u in enumerate(batch_users):
+                    user_emb_dict[u] = last_state[j]
+
+        self.data_dir.mkdir(parents=True, exist_ok=True)
+        with open(self.data_dir / "user_seq_emb.pkl", "wb") as f:
+            pickle.dump(user_emb_dict, f)
+        with open(self.data_dir / "item_map.pkl", "wb") as f:
+            pickle.dump(self.item_map, f)
+        with open(self.data_dir / "user_sequences.pkl", "wb") as f:
+            pickle.dump(user_seqs, f)
+
+        self.user_seq_emb = user_emb_dict
+        self.user_hist = {
+            u: set(self.id_to_item[idx] for idx in seq if idx in self.id_to_item)
+            for u, seq in user_seqs.items()
+        }
+        self.item_emb = model.item_emb.weight.detach().cpu().numpy()
+        self._build_faiss_index()
+        self.loaded = True
+
+        logger.info(f"SASRec saved to {sasrec_dir}")
+        return self
+
+    def _build_faiss_index(self) -> None:
+        """Build Faiss index from item embeddings."""
+        if self.item_emb is None:
+            return
+        dim = self.item_emb.shape[1]
+        self.faiss_index = faiss.IndexFlatIP(dim)
+        self.faiss_index.add(np.ascontiguousarray(self.item_emb.astype(np.float32)))
+
+    def load(self) -> bool:
         try:
             logger.info("Loading SASRec recall embeddings...")

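
Note: the recall path above stores item embeddings in a Faiss `IndexFlatIP` and scores users by inner product. A minimal standalone sketch of that pattern, with random embeddings and illustrative dimensions only:

import numpy as np
import faiss

rng = np.random.default_rng(0)
item_emb = rng.normal(size=(1000, 64)).astype(np.float32)  # stand-in for SASRec item_emb.weight

index = faiss.IndexFlatIP(item_emb.shape[1])                # exact inner-product index
index.add(np.ascontiguousarray(item_emb))

user_vec = rng.normal(size=(1, 64)).astype(np.float32)      # stand-in for a user's last hidden state
scores, ids = index.search(user_vec, 10)                    # top-10 item ids by inner product
print(ids[0], scores[0])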
src/recall/sequence_utils.py
ADDED
@@ -0,0 +1,48 @@
+"""
+Shared utilities for building user sequences from interaction DataFrames.
+Used by SASRec and YoutubeDNN training.
+"""
+
+from typing import Tuple
+
+import pandas as pd
+from tqdm import tqdm
+
+
+def build_sequences_from_df(
+    df: pd.DataFrame, max_len: int = 50
+) -> Tuple[dict[str, list[int]], dict[str, int]]:
+    """
+    Build user sequences and item map from interaction DataFrame.
+
+    Args:
+        df: DataFrame with columns [user_id, isbn] and optionally [timestamp].
+        max_len: Maximum sequence length (truncate from the left).
+
+    Returns:
+        user_seqs: Dict[user_id, list of item_ids] (1-indexed, 0 is padding)
+        item_map: Dict[isbn, item_id]
+    """
+    items = df["isbn"].astype(str).unique()
+    item_map = {isbn: i + 1 for i, isbn in enumerate(items)}
+
+    user_history: dict[str, list[tuple[str, float]]] = {}
+    has_ts = "timestamp" in df.columns
+
+    for _, row in tqdm(df.iterrows(), total=len(df), desc="Building sequences"):
+        u = str(row["user_id"])
+        isbn = str(row["isbn"])
+        ts = float(row["timestamp"]) if has_ts else 0.0
+        if u not in user_history:
+            user_history[u] = []
+        user_history[u].append((isbn, ts))
+
+    user_seqs: dict[str, list[int]] = {}
+    for u, pairs in user_history.items():
+        if has_ts:
+            pairs.sort(key=lambda x: x[1])
+        item_ids = [item_map.get(isbn, 0) for isbn, _ in pairs]
+        item_ids = [x for x in item_ids if x != 0]
+        user_seqs[u] = item_ids[-max_len:]
+
+    return user_seqs, item_map
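
Note: a quick usage sketch of the new helper with a toy DataFrame; the values are illustrative and the real callers are SASRecRecall.fit and YoutubeDNNRecall.fit above.

import pandas as pd
from src.recall.sequence_utils import build_sequences_from_df

df = pd.DataFrame({
    "user_id": ["u1", "u1", "u1", "u2", "u2"],
    "isbn": ["b1", "b2", "b3", "b2", "b1"],
    "timestamp": [3, 1, 2, 1, 2],
})

user_seqs, item_map = build_sequences_from_df(df, max_len=50)
# item_map assigns 1-based ids in first-seen order: {"b1": 1, "b2": 2, "b3": 3}
# u1's rows are re-sorted by timestamp, so its sequence is [b2, b3, b1] -> [2, 3, 1]
print(user_seqs)  # {"u1": [2, 3, 1], "u2": [2, 1]}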
src/recommender.py
CHANGED
@@ -1,13 +1,9 @@
-import pandas as pd
 from typing import List, Dict, Any
-from src.etl import load_books_data
 from src.vector_db import VectorDB
 from src.config import TOP_K_INITIAL, TOP_K_FINAL, DATA_DIR
 from src.cache import CacheManager

-from src.utils import setup_logger
-from src.cover_fetcher import fetch_book_cover
-from src.marketing.personalized_highlight import get_persona_and_highlights
+from src.utils import setup_logger
 from src.core.metadata_store import metadata_store

 logger = setup_logger(__name__)
src/services/chat_service.py
CHANGED
@@ -1,22 +1,22 @@
 from typing import Generator, Optional, Dict, Any, List
-import pandas as pd
 from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, BaseMessage

 from src.core.llm import LLMFactory
-from src.
+from src.data.repository import data_repository
 from src.marketing.persona import build_persona
 from src.user.profile_store import list_favorites
 from src.utils import setup_logger

 logger = setup_logger(__name__)

+
 class ChatService:
     """
     Service for RAG-based chat interaction.
     Currently focused on 'Chat with Book' (Single Item Context).
+    Uses DataRepository for all book metadata lookups.
     """
     _instance = None
-    _books_df = None
     _history: Dict[str, List[BaseMessage]] = {}

     def __new__(cls):
@@ -25,25 +25,11 @@ class ChatService:
         return cls._instance

     def __init__(self):
-        # Data is now loaded lazily via _ensure_data
         pass

-    def _ensure_data(self):
-        if self._books_df is None:
-            logger.info("ChatService: Lazy-loading books data for context retrieval...")
-            self._books_df = load_books_data()
-
     def _get_book_context(self, isbn: str) -> Optional[Dict[str, Any]]:
-        """Retrieve full context for a specific book by ISBN."""
-
-        # Handle string/int types for ISBN
-        try:
-            row = self._books_df[self._books_df["isbn13"].astype(str) == str(isbn)]
-            if row.empty:
-                return None
-            return row.iloc[0].to_dict()
-        except Exception:
-            return None
+        """Retrieve full context for a specific book by ISBN via DataRepository."""
+        return data_repository.get_book_metadata(str(isbn))

     def _format_book_info(self, book: Dict[str, Any]) -> str:
         """Format book metadata into a readable context string."""
@@ -97,8 +83,7 @@
         """
         Stream chat response for a specific book.
         """
-
-        # 1. Fetch Context
+        # 1. Fetch Context via DataRepository
         book = self._get_book_context(isbn)
         if not book:
             yield "I'm sorry, I couldn't find the details for this book."
@@ -106,7 +91,7 @@

         # 2. Build Persona (User Profile)
         favs = list_favorites(user_id)
-        persona_data = build_persona(favs
+        persona_data = build_persona(favs)
         user_persona = persona_data.get("summary", "General Reader")

         # 3. Construct Prompt with History
@@ -158,15 +143,11 @@
             yield f"Error generating response: {str(e)}. Please check your API Key."

     def add_book_to_context(self, book_data: Dict[str, Any]):
-        """
-
-            self._books_df = pd.concat([self._books_df, new_row_df], ignore_index=True)
-            logger.info(f"ChatService: Added book {book_data.get('isbn13')} to context.")
-        except Exception as e:
-            logger.error(f"ChatService: Failed to add book to context: {e}")
+        """
+        Called when a new book is added to the system. Book is already in MetadataStore
+        via recommender.add_new_book, so no in-memory cache to update. No-op for now.
+        """
+        logger.info(f"ChatService: Book {book_data.get('isbn13')} added; context served from MetadataStore.")

 def get_chat_service():
     """Helper for lazy access to the ChatService singleton."""
src/vector_db.py
CHANGED
@@ -1,10 +1,7 @@
-import gc
 from typing import List, Any
 # Using community version to avoid 'BaseBlobParser' version conflict in langchain-chroma/core
 from langchain_community.vectorstores import Chroma
 from langchain_huggingface import HuggingFaceEmbeddings
-from langchain_community.document_loaders import TextLoader
-from langchain_text_splitters import CharacterTextSplitter
 from src.config import REVIEW_HIGHLIGHTS_TXT, CHROMA_DB_DIR, EMBEDDING_MODEL
 from src.utils import setup_logger
 from src.core.metadata_store import metadata_store