ymlin105 committed on
Commit 3f281f1 · 1 Parent(s): 950f43a

chore: remove legacy files and scripts no longer part of the main architecture

.cursorrules ADDED
@@ -0,0 +1,10 @@
+ You are a Senior Python Systems Architect specializing in Machine Learning Engineering.
+ Your goal is to refactor a research-prototype code base into a production-grade system.
+
+ Guidelines:
+ 1. **Code Structure**: Follow Clean Architecture principles. Separate concerns strictly between Data Access, Business Logic, and Interface.
+ 2. **Type Hinting**: All new or refactored functions MUST have Python type hints and docstrings.
+ 3. **No "Glue Scripts"**: Avoid using `subprocess.run` to call other Python scripts. Import classes and call methods instead.
+ 4. **Error Handling**: Use specific exception handling, not bare `except Exception`.
+ 5. **Paths**: Always use `pathlib.Path`, never `os.path.join`.
+ 6. **Refactoring Safety**: When refactoring, ensure existing logic (feature engineering, recall calculation) is preserved unless explicitly asked to simplify.
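A minimal sketch of what rules 2, 4, and 5 above look like in practice (a hypothetical illustration, not code from this commit; the function name and path are made up):

from pathlib import Path

import pandas as pd


def load_books(csv_path: Path) -> pd.DataFrame:
    """Load the books table, failing with a specific error if the file is missing."""
    # Rule 5: pathlib.Path instead of os.path.join; Rule 4: a specific exception, not a bare except.
    if not csv_path.exists():
        raise FileNotFoundError(f"Books file not found: {csv_path}")
    return pd.read_csv(csv_path)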
legacy/README.md DELETED
@@ -1,10 +0,0 @@
- # Legacy — Not part of main architecture
-
- Code moved here is preserved but not used in the main flow (src.main FastAPI + React).
-
- | File | Note |
- |:---|:---|
- | app.py | Gradio UI (replaced by React + FastAPI) |
- | agent/ | Shopping agent (broken imports, not used) |
- | deploy.sh | Old Hugging Face deployment script |
- | download_fix.py | Temporary fix script |
legacy/agent/agent_core.py DELETED
@@ -1,55 +0,0 @@
1
- from intent_parser import IntentParser
2
- from rag_retriever import ProductRetriever
3
- from dialogue_manager import DialogueManager
4
- from llm_generator import LLMGenerator
5
- import os
6
-
7
- class ShoppingAgent:
8
- def __init__(self, index_path: str, metadata_path: str, llm_model: str = None):
9
- self.parser = IntentParser()
10
- self.retriever = ProductRetriever(index_path, metadata_path)
11
- self.dialogue_manager = DialogueManager()
12
- self.llm = LLMGenerator(model_name=llm_model) # Defaults to mock
13
-
14
- def process_query(self, query: str):
15
- print(f"\nUser: {query}")
16
-
17
- # 1. Parse Intent
18
- intent = self.parser.parse(query)
19
- # print(f"[Debug] Intent: {intent}")
20
-
21
- # 2. Enrich Query (incorporating history could happen here)
22
- search_query = query
23
- if intent['category']:
24
- search_query += f" {intent['category']}"
25
-
26
- # 3. Retrieve
27
- results = self.retriever.search(search_query, k=3)
28
-
29
- # 4. Generate Response using LLM + History
30
- history_str = self.dialogue_manager.get_context_string()
31
- response = self.llm.generate_response(query, results, history_str)
32
-
33
- # 5. Update Memory
34
- self.dialogue_manager.add_turn(query, response)
35
-
36
- print("[Agent]:")
37
- print(response)
38
- return response
39
-
40
- def reset(self):
41
- self.dialogue_manager.clear_history()
42
-
43
- if __name__ == "__main__":
44
- if not os.path.exists("data/product_index.faiss"):
45
- print("Index not found. Please run rag_indexer.py first.")
46
- else:
47
- # Pass "mock" to force CPU-friendly mock generation,
48
- # or pass a model name like "gpt2" (small) if you have 'transformers' installed to test pipeline.
49
- agent = ShoppingAgent("data/product_index.faiss", "data/product_metadata.pkl", llm_model="mock")
50
-
51
- print("--- Turn 1 ---")
52
- agent.process_query("I need a gaming laptop under $1000")
53
-
54
- print("\n--- Turn 2 ---")
55
- agent.process_query("Do you have anything cheaper?")
legacy/agent/data_loader.py DELETED
@@ -1,61 +0,0 @@
1
- import pandas as pd
2
- import random
3
-
4
- def generate_synthetic_data(num_samples: int = 100) -> pd.DataFrame:
5
- """
6
- Generates synthetic e-commerce product data.
7
- """
8
- categories = ['Electronics', 'Clothing', 'Home & Kitchen', 'Books', 'Toys']
9
- adjectives = ['Premium', 'Budget', 'High-end', 'Durable', 'Stylish', 'Compact', 'Professional']
10
- products_map = {
11
- 'Electronics': ['Smartphone', 'Laptop', 'Headphones', 'Smartwatch', 'Camera'],
12
- 'Clothing': ['T-Shirt', 'Jeans', 'Jacket', 'Sneakers', 'Dress'],
13
- 'Home & Kitchen': ['Blender', 'Coffee Maker', 'Desk Lamp', 'Sofa', 'Curtains'],
14
- 'Books': ['Novel', 'Textbook', 'Biography', 'Cookbook', 'Comic'],
15
- 'Toys': ['Lego Set', 'Action Figure', 'Board Game', 'Puzzle', 'Doll']
16
- }
17
-
18
- data = []
19
- for i in range(num_samples):
20
- cat = random.choice(categories)
21
- prod = random.choice(products_map[cat])
22
- adj = random.choice(adjectives)
23
-
24
- title = f"{adj} {prod} {i+1}"
25
- price = round(random.uniform(10.0, 1000.0), 2)
26
- description = f"This is a {adj.lower()} {prod.lower()} perfect for your needs. It features high quality materials and modern design."
27
- features = f"Feature A, Feature B, {adj} Quality"
28
-
29
- data.append({
30
- 'product_id': f"P{str(i).zfill(4)}",
31
- 'title': title,
32
- 'category': cat,
33
- 'price': price,
34
- 'description': description,
35
- 'features': features,
36
- 'review_text': f"Great {prod}! I loved the {adj.lower()} aspect."
37
- })
38
-
39
- return pd.DataFrame(data)
40
-
41
- def load_data(file_path: str = None) -> pd.DataFrame:
42
- """
43
- Loads data from a file or generates synthetic data if path is None.
44
- """
45
- if file_path:
46
- # Check extension and load accordingly
47
- if file_path.endswith('.csv'):
48
- return pd.read_csv(file_path)
49
- elif file_path.endswith('.json'):
50
- return pd.read_json(file_path)
51
- else:
52
- raise ValueError("Unsupported file format")
53
- else:
54
- print("No file path provided. Generating synthetic data...")
55
- return generate_synthetic_data()
56
-
57
- if __name__ == "__main__":
58
- df = load_data()
59
- print(df.head())
60
- df.to_csv("synthetic_products.csv", index=False)
61
- print("Saved synthetic_products.csv")
legacy/agent/dialogue_manager.py DELETED
@@ -1,39 +0,0 @@
1
- from typing import List, Dict
2
-
3
- class DialogueManager:
4
- def __init__(self, max_history: int = 5):
5
- self.history: List[Dict[str, str]] = []
6
- self.max_history = max_history
7
-
8
- def add_turn(self, user_input: str, system_response: str):
9
- """
10
- Adds a single turn to the history.
11
- """
12
- self.history.append({"role": "user", "content": user_input})
13
- self.history.append({"role": "assistant", "content": system_response})
14
-
15
- # Keep history within limit (rolling buffer)
16
- if len(self.history) > self.max_history * 2:
17
- self.history = self.history[-(self.max_history * 2):]
18
-
19
- def get_history(self) -> List[Dict[str, str]]:
20
- """
21
- Returns the conversation history.
22
- """
23
- return self.history
24
-
25
- def clear_history(self):
26
- """
27
- Resets the conversation.
28
- """
29
- self.history = []
30
-
31
- def get_context_string(self) -> str:
32
- """
33
- Returns history formatted as a string for simple prompts.
34
- """
35
- context = ""
36
- for turn in self.history:
37
- role = "User" if turn["role"] == "user" else "Agent"
38
- context += f"{role}: {turn['content']}\n"
39
- return context
legacy/agent/intent_parser.py DELETED
@@ -1,64 +0,0 @@
1
- import re
2
- from typing import Dict, Optional
3
-
4
- class IntentParser:
5
- def __init__(self):
6
- # In a real scenario, this would be an LLM-based parser
7
- pass
8
-
9
- def parse(self, query: str) -> Dict[str, Optional[str]]:
10
- """
11
- Parses the user query into structured slots.
12
- """
13
- query = query.lower()
14
-
15
- intent = {
16
- 'category': None,
17
- 'budget': None,
18
- 'style': None,
19
- 'original_query': query
20
- }
21
-
22
- # Rule-based Category Extraction
23
- categories = ['laptop', 'phone', 'smartphone', 'headphone', 'camera', 'jeans', 'shirt', 'dress', 'shoe', 'blender', 'coffee', 'lamp', 'sofa', 'desk', 'toy', 'lego', 'book', 'novel']
24
- for cat in categories:
25
- if cat in query:
26
- intent['category'] = cat
27
- break # Take the first match for now
28
-
29
- # Rule-based Budget Extraction
30
- # Look for "under $100", "cheap", "expensive", "budget"
31
- if "cheap" in query or "budget" in query:
32
- intent['budget'] = "low"
33
- elif "expensive" in query or "premium" in query:
34
- intent['budget'] = "high"
35
-
36
- match = re.search(r'under \$?(\d+)', query)
37
- if match:
38
- intent['budget'] = f"<{match.group(1)}"
39
-
40
- # Rule-based Style/Feature Extraction (naïve)
41
- # Everything else that is an adjective could be style
42
- styles = ['gaming', 'professional', 'casual', 'formal', 'black', 'red', 'blue', 'wireless', 'bluetooth']
43
- found_styles = []
44
- for style in styles:
45
- if style in query:
46
- found_styles.append(style)
47
-
48
- if found_styles:
49
- intent['style'] = ", ".join(found_styles)
50
-
51
- return intent
52
-
53
- if __name__ == "__main__":
54
- parser = IntentParser()
55
- queries = [
56
- "I want a cheap gaming laptop",
57
- "Looking for a blue dress under $50",
58
- "wireless headphones for travel"
59
- ]
60
-
61
- for q in queries:
62
- print(f"Query: {q}")
63
- print(f"Parsed: {parser.parse(q)}")
64
- print("-" * 20)
legacy/agent/llm_generator.py DELETED
@@ -1,92 +0,0 @@
1
- from typing import List, Dict, Optional
2
- import os
3
-
4
- class LLMGenerator:
5
- def __init__(self, model_name: str = None, device: str = "cpu"):
6
- """
7
- Initialize LLM.
8
- Args:
9
- model_name: HuggingFace model name (e.g., 'meta-llama/Meta-Llama-3-8B-Instruct').
10
- If None, uses a Mock generator.
11
- device: 'cpu' or 'cuda'.
12
- """
13
- self.model_name = model_name
14
- self.device = device
15
- self.pipeline = None
16
-
17
- if self.model_name and self.model_name != "mock":
18
- try:
19
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
20
- import torch
21
-
22
- print(f"Loading LLM: {model_name} on {device}...")
23
- # Note: In a real script, we would handle quantization (bitsandbytes) here
24
- # based on the device capabilities we discussed.
25
- dtype = torch.float16 if device == 'cuda' else torch.float32
26
-
27
- self.pipeline = pipeline(
28
- "text-generation",
29
- model=model_name,
30
- torch_dtype=dtype,
31
- device_map="auto" if device == 'cuda' else "cpu"
32
- )
33
- except Exception as e:
34
- print(f"Failed to load model {model_name}: {e}")
35
- print("Falling back to Mock Generator.")
36
- self.model_name = "mock"
37
-
38
- def generate_response(self, user_query: str, retrieved_items: List[Dict], history_str: str) -> str:
39
- """
40
- Generates a natural language response based on context.
41
- """
42
- # 1. Format retrieved items
43
- items_str = ""
44
- for i, item in enumerate(retrieved_items):
45
- items_str += f"{i+1}. {item['title']} (${item['price']}): {item['description']}\n"
46
-
47
- # 2. Construct Prompt (Simple Template)
48
- prompt = f"""You are a helpful shopping assistant.
49
-
50
- Context History:
51
- {history_str}
52
-
53
- Retrieved Products related to the user's request:
54
- {items_str}
55
-
56
- User's Query: {user_query}
57
-
58
- Instructions:
59
- - Recommend the best products from the list above.
60
- - Explain WHY they fit the user's request (budget, style, category).
61
- - Be concise and friendly.
62
-
63
- Response:"""
64
-
65
- if self.model_name == "mock" or self.model_name is None:
66
- return self._mock_generation(items_str)
67
- else:
68
- # Real LLM Generation
69
- try:
70
- outputs = self.pipeline(
71
- prompt,
72
- max_new_tokens=200,
73
- do_sample=True,
74
- temperature=0.7,
75
- truncation=True
76
- )
77
- generated_text = outputs[0]['generated_text']
78
- # Extract only the response part if the model echos the prompt (common in base pipelines)
79
- if "Response:" in generated_text:
80
- return generated_text.split("Response:")[-1].strip()
81
- return generated_text
82
- except Exception as e:
83
- return f"[Error generating response: {e}]"
84
-
85
- def _mock_generation(self, items_str):
86
- """
87
- Fallback logic for testing without a GPU.
88
- """
89
- if not items_str:
90
- return "I couldn't find any products matching your specific criteria. Could you try different keywords?"
91
-
92
- return f"Based on your request, I found these great options:\n{items_str}\nI recommend checking the first one as it offers the best value!"
legacy/agent/rag_indexer.py DELETED
@@ -1,56 +0,0 @@
1
- import os
2
- import faiss
3
- import numpy as np
4
- import pickle
5
- from sentence_transformers import SentenceTransformer
6
- import pandas as pd
7
- try:
8
- from src.data_loader import load_data
9
- except ImportError:
10
- from data_loader import load_data # Fallback for direct execution
11
-
12
- class RAGIndexer:
13
- def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
14
- self.model = SentenceTransformer(model_name)
15
- self.index = None
16
- self.metadata = []
17
-
18
- def build_index(self, data: pd.DataFrame):
19
- """
20
- Builds the Faiss index from the product dataframe.
21
- """
22
- print("Encoding product data...")
23
- # Create a rich text representation for embedding
24
- # Title + Description + Features + Category + Price (as text)
25
- documents = data.apply(lambda x: f"{x['title']} {x['description']} Category: {x['category']} Price: {x['price']}", axis=1).tolist()
26
-
27
- embeddings = self.model.encode(documents, show_progress_bar=True)
28
- dimension = embeddings.shape[1]
29
-
30
- self.index = faiss.IndexFlatL2(dimension)
31
- self.index.add(embeddings.astype('float32'))
32
-
33
- self.metadata = data.to_dict('records')
34
- print(f"Index built with {len(self.metadata)} items.")
35
-
36
- def save(self, index_path: str, metadata_path: str):
37
- """
38
- Saves the index and metadata to disk.
39
- """
40
- if self.index:
41
- faiss.write_index(self.index, index_path)
42
- with open(metadata_path, 'wb') as f:
43
- pickle.dump(self.metadata, f)
44
- print(f"Saved index to {index_path} and metadata to {metadata_path}")
45
- else:
46
- print("No index to save.")
47
-
48
- if __name__ == "__main__":
49
- # Scaffolding run
50
- df = load_data() # Generates synthetic
51
- indexer = RAGIndexer()
52
- indexer.build_index(df)
53
-
54
- # Ensure output dir exists
55
- os.makedirs("data", exist_ok=True)
56
- indexer.save("data/product_index.faiss", "data/product_metadata.pkl")
legacy/agent/rag_retriever.py DELETED
@@ -1,42 +0,0 @@
1
- import faiss
2
- import pickle
3
- import numpy as np
4
- from sentence_transformers import SentenceTransformer
5
- from typing import List, Dict
6
-
7
- class ProductRetriever:
8
- def __init__(self, index_path: str, metadata_path: str, model_name: str = 'all-MiniLM-L6-v2'):
9
- self.model = SentenceTransformer(model_name)
10
-
11
- print(f"Loading index from {index_path}...")
12
- self.index = faiss.read_index(index_path)
13
-
14
- print(f"Loading metadata from {metadata_path}...")
15
- with open(metadata_path, 'rb') as f:
16
- self.metadata = pickle.load(f)
17
-
18
- def search(self, query: str, k: int = 5) -> List[Dict]:
19
- """
20
- Searches for the top-k most relevant products.
21
- """
22
- query_vector = self.model.encode([query]).astype('float32')
23
- distances, indices = self.index.search(query_vector, k)
24
-
25
- results = []
26
- for i, idx in enumerate(indices[0]):
27
- if idx < len(self.metadata):
28
- item = self.metadata[idx]
29
- item['score'] = float(distances[0][i])
30
- results.append(item)
31
-
32
- return results
33
-
34
- if __name__ == "__main__":
35
- # Test run
36
- retriever = ProductRetriever("data/product_index.faiss", "data/product_metadata.pkl")
37
- query = "cheap gaming laptop"
38
- results = retriever.search(query)
39
-
40
- print(f"Query: {query}")
41
- for res in results:
42
- print(f" - {res['title']} (${res['price']}) [Score: {res['score']:.4f}]")
legacy/app.py DELETED
@@ -1,264 +0,0 @@
1
- import gradio as gr
2
- import logging
3
- import os
4
- import requests
5
- import json
6
- from typing import List, Tuple, Any
7
- from src.utils import setup_logger
8
-
9
- # --- Configuration ---
10
- API_URL = os.getenv("API_URL", "http://localhost:6006") # Localhost via SSH Tunnel
11
-
12
- # --- Initialize Logger ---
13
- logger = setup_logger(__name__)
14
-
15
- # --- Module Initialization ---
16
- # (We no longer load model locally; we query the remote API)
17
- categories = ["All", "Fiction", "History", "Science", "Technology"] # Fallback/Mock for now
18
- tones = ["All", "Happy", "Surprising", "Angry", "Suspenseful", "Sad"]
19
-
20
- def fetch_tones():
21
- try:
22
- resp = requests.get(f"{API_URL}/tones", timeout=3)
23
- if resp.status_code == 200:
24
- data = resp.json()
25
- tns = data.get("tones") if isinstance(data, dict) else None
26
- if isinstance(tns, list) and len(tns) > 0:
27
- return tns
28
- except Exception as e:
29
- logger.warning(f"fetch_tones failed: {e}")
30
- return tones
31
-
32
- def fetch_categories():
33
- try:
34
- resp = requests.get(f"{API_URL}/categories", timeout=3)
35
- if resp.status_code == 200:
36
- data = resp.json()
37
- cats = data.get("categories") if isinstance(data, dict) else None
38
- if isinstance(cats, list) and len(cats) > 0:
39
- return cats
40
- except Exception as e:
41
- logger.warning(f"fetch_categories failed: {e}")
42
- return categories
43
-
44
- # Try to fetch real categories on startup
45
- categories = fetch_categories()
46
- tones = fetch_tones()
47
-
48
- # Initialize Shopping Agent (Mock or Real)
49
- # Note: Real agent requires FAISS index. We'll handle checks later.
50
- try:
51
- # from legacy.agent.agent_core import ShoppingAgent
52
- # shopping_agent = ShoppingAgent(...)
53
- pass
54
- except ImportError:
55
- logger.warning("Shopping Agent module not found or failed to import.")
56
-
57
- # --- Business Logic: Tab 1 (Discovery) ---
58
- def recommend_books(query: str, category: str, tone: str):
59
- """Fetch recommendations and return both gallery items and raw data."""
60
- try:
61
- if not query.strip():
62
- return [], []
63
-
64
- payload = {
65
- "query": query,
66
- "category": category if category else "All",
67
- "tone": tone if tone else "All"
68
- }
69
-
70
- logger.info(f"Sending request to {API_URL}/recommend")
71
- response = requests.post(f"{API_URL}/recommend", json=payload, timeout=25)
72
-
73
- if response.status_code == 200:
74
- data = response.json()
75
- results = data.get("recommendations", [])
76
- gallery_items = [(item["thumbnail"], f"{item['title']}\n{item['authors']}") for item in results]
77
- return gallery_items, results
78
- else:
79
- logger.error(f"API Error: {response.text}")
80
- return [], []
81
-
82
- except Exception as e:
83
- logger.error(f"Error in recommend_books: {e}")
84
- return [], []
85
-
86
-
87
- def show_book_details(evt: Any, recs: List[dict]):
88
- """Populate detail panel when a gallery item is selected and prep a QA hint."""
89
- try:
90
- if recs is None:
91
- return "", "", "", "", "", -1
92
- idx = evt.index if evt and hasattr(evt, "index") else None
93
- if idx is None or idx >= len(recs):
94
- return "", "", "", "", "", -1
95
- book = recs[idx]
96
- title_block = f"### {book['title']}\n**Authors:** {book['authors']}\n**ISBN:** {book['isbn']}"
97
- desc_block = f"**Description**\n\n{book['description']}"
98
- rank_block = f"**Rank:** #{idx + 1}" # simple positional rank
99
- comments_block = "**Reviews (sample):**\n- Exceptional pacing and character depth.\n- A must-read for this genre."
100
- qa_hint = f"Ask the assistant: Tell me more about '{book['title']}' by {book['authors']}."
101
- return title_block, rank_block, comments_block, desc_block, qa_hint, idx
102
- except Exception as e:
103
- logger.error(f"Error showing book details: {e}")
104
- return "", "", "", "", "", -1
105
-
106
- def clear_discovery():
107
- return "", "All", "All", []
108
-
109
-
110
- def add_to_favorites(selected_idx: int, recs: List[dict]):
111
- try:
112
- if selected_idx is None or selected_idx < 0 or not recs or selected_idx >= len(recs):
113
- return "Please select a book from the gallery first."
114
- book = recs[selected_idx]
115
- payload = {"user_id": "local", "isbn": book["isbn"]}
116
- resp = requests.post(f"{API_URL}/favorites/add", json=payload, timeout=8)
117
- if resp.status_code == 200:
118
- data = resp.json()
119
- return f"✅ Added to favorites: {book['title']} ({data.get('favorites_count', '?')} books in collection)"
120
- return f"❌ Failed to add: {resp.text}"
121
- except Exception as e:
122
- logger.error(f"add_to_favorites error: {e}")
123
- return "❌ Error adding to favorites. Try again later."
124
-
125
-
126
- def generate_highlights(selected_idx: int, recs: List[dict]):
127
- try:
128
- if selected_idx is None or selected_idx < 0 or not recs or selected_idx >= len(recs):
129
- return "(Hint) Please select a book from the gallery, then click Generate Highlights."
130
- book = recs[selected_idx]
131
- payload = {"isbn": book["isbn"], "user_id": "local"}
132
- resp = requests.post(f"{API_URL}/marketing/highlights", json=payload, timeout=12)
133
- if resp.status_code != 200:
134
- return "Failed to generate highlights. Try again later."
135
- data = resp.json()
136
- persona = data.get("persona", {})
137
- highlights = data.get("highlights", [])
138
- header = f"### Personalized Highlights ({book['title']})\n"
139
- persona_md = f"> Your Profile: {persona.get('summary','N/A')}\n\n" if persona else ""
140
- bullets = "\n".join([f"- {h}" for h in highlights]) if highlights else "- No highlights available"
141
- return header + persona_md + bullets
142
- except Exception as e:
143
- logger.error(f"generate_highlights error: {e}")
144
- return "Error generating highlights. Try again later."
145
-
146
- # --- Business Logic: Tab 2 (Assistant) ---
147
- def chat_response(message, history):
148
- """Answer book questions using the recommender API as a knowledge source."""
149
- try:
150
- if not message.strip():
151
- return "Please describe the book or question you have."
152
-
153
- # Use the same recommend endpoint as retrieval to ground answers
154
- payload = {"query": message, "category": "All", "tone": "All"}
155
- resp = requests.post(f"{API_URL}/recommend", json=payload, timeout=20)
156
- if resp.status_code != 200:
157
- return "Unable to retrieve book information. Try again later."
158
-
159
- data = resp.json()
160
- recs = data.get("recommendations", [])
161
- if not recs:
162
- return "No matching books found. Try a different query."
163
-
164
- top = recs[0]
165
- answer = [
166
- f"**{top.get('title','')}**",
167
- f"Author: {top.get('authors','Unknown')}",
168
- f"Summary: {top.get('description','No summary available')}"
169
- ]
170
- # If more results, suggest to check discovery tab
171
- if len(recs) > 1:
172
- answer.append("More results available in the Find Books tab.")
173
- return "\n\n".join(answer)
174
- except Exception as e:
175
- logger.error(f"chat_response error: {e}")
176
- return "Error processing your question. Try again later."
177
-
178
- # --- Business Logic: Tab 3 (Marketing) ---
179
- def generate_marketing_copy(product_name, features, target_audience):
180
- # Placeholder for Marketing Content Engine
181
- # from src.marketing.guardrails import SafetyCheck...
182
- return f"""
183
- 📣 **CALLING ALL {target_audience.upper()}!**
184
-
185
- Presenting **{product_name}** — the treasure you've been seeking.
186
-
187
- ✨ **Why you'll love it:**
188
- {features}
189
-
190
- Perfect for your collection. Add it to your shelf today.
191
- """
192
-
193
- # --- UI Construction ---
194
- with gr.Blocks(title="Paper Shelf - Book Discovery", theme=gr.themes.Soft()) as dashboard:
195
-
196
- gr.Markdown("# 📚 Paper Shelf")
197
- gr.Markdown("Intelligent book discovery powered by semantic search: **Find Books**, **Ask Questions**, **Generate Marketing Copy**.")
198
-
199
- with gr.Tabs():
200
-
201
- # --- Tab 1: Discovery ---
202
- with gr.TabItem("🔍 Find Books (Search & Recommendations)"):
203
- rec_state = gr.State([]) # store full recommendation data
204
- qa_hint = gr.State("")
205
- sel_idx = gr.State(-1)
206
- with gr.Row():
207
- with gr.Column(scale=3):
208
- q_input = gr.Textbox(label="What are you looking for?", placeholder="e.g., a mystery novel with fast pacing")
209
- with gr.Column(scale=1):
210
- cat_input = gr.Dropdown(label="Category", choices=categories, value="All")
211
- tone_input = gr.Dropdown(label="Mood/Tone", choices=tones, value="All")
212
-
213
- btn_rec = gr.Button("Find Books", variant="primary")
214
- gallery = gr.Gallery(label="Results", columns=4, height="auto")
215
- with gr.Row():
216
- with gr.Column(scale=2):
217
- title_info = gr.Markdown(label="Book Info")
218
- desc_info = gr.Markdown(label="Description")
219
- with gr.Column(scale=1):
220
- rank_info = gr.Markdown(label="Ranking")
221
- comments_info = gr.Markdown(label="Reviews")
222
- qa_hint_md = gr.Markdown(label="Ask the Assistant", value="(Click a book to see suggested questions)")
223
-
224
- with gr.Row():
225
- btn_fav = gr.Button("⭐ Add to Favorites", variant="secondary")
226
- btn_high = gr.Button("✨ Generate Highlights", variant="primary")
227
- fav_status = gr.Markdown(label="Status")
228
- highlights_md = gr.Markdown(label="Personalized Highlights")
229
-
230
- btn_rec.click(recommend_books, [q_input, cat_input, tone_input], [gallery, rec_state])
231
- gallery.select(show_book_details, [rec_state], [title_info, rank_info, comments_info, desc_info, qa_hint_md, sel_idx])
232
- btn_fav.click(add_to_favorites, [sel_idx, rec_state], [fav_status])
233
- btn_high.click(generate_highlights, [sel_idx, rec_state], [highlights_md])
234
-
235
- # --- Tab 2: AI Assistant ---
236
- with gr.TabItem("💬 Ask Questions (RAG Assistant)"):
237
- chatbot = gr.ChatInterface(
238
- fn=chat_response,
239
- examples=["Is there a mystery with time travel?", "Recommend sci-fi with female protagonists"],
240
- title="Intelligent Book Assistant",
241
- description="Search and learn about books through conversational AI."
242
- )
243
-
244
- # --- Tab 3: Marketing ---
245
- with gr.TabItem("✍️ Create Marketing Copy (GenAI)"):
246
- with gr.Row():
247
- m_name = gr.Textbox(label="Book Title/Hook", value="The Hobbit - First Edition, Near Mint")
248
- m_feat = gr.Textbox(label="Key Features/Condition", value="Near mint condition, no markings, ships worldwide")
249
- m_aud = gr.Textbox(label="Target Audience", value="Fantasy enthusiasts, collectors")
250
-
251
- btn_gen = gr.Button("Generate Listing", variant="primary")
252
- m_out = gr.Markdown(label="Generated Copy")
253
-
254
- btn_gen.click(generate_marketing_copy, [m_name, m_feat, m_aud], m_out)
255
-
256
- if __name__ == "__main__":
257
- import os
258
- assets_path = os.path.join(os.path.dirname(__file__), "assets")
259
- dashboard.launch(
260
- server_name="0.0.0.0",
261
- server_port=7860,
262
- allowed_paths=[assets_path],
263
- share=True
264
- )
legacy/deploy.sh DELETED
@@ -1,47 +0,0 @@
- #!/bin/bash
-
- echo "🚀 Preparing deployment to Hugging Face Spaces..."
-
- # Check required files
- echo "📋 Checking required files..."
- required_files=("gradio-dashboard.py" "books_with_emotions.csv" "books_descriptions.txt" "cover-not-found.jpg" "requirements.txt")
-
- for file in "${required_files[@]}"; do
-     if [ -f "$file" ]; then
-         echo "✅ $file exists"
-     else
-         echo "❌ $file is missing"
-         exit 1
-     fi
- done
-
- # Rename the main file to app.py (Hugging Face convention)
- if [ -f "gradio-dashboard.py" ]; then
-     cp gradio-dashboard.py app.py
-     echo "✅ Created app.py"
- fi
-
- # Check Git status
- echo "📝 Checking Git status..."
- if [ -d ".git" ]; then
-     echo "✅ Git repository already initialized"
-     git status
- else
-     echo "⚠️ No Git repository detected; run the following first:"
-     echo "   git init"
-     echo "   git add ."
-     echo "   git commit -m 'Prepare deployment'"
-     echo "   git remote add origin https://github.com/<your-username>/book-recommender.git"
-     echo "   git push -u origin main"
- fi
-
- echo ""
- echo "🎯 Next steps:"
- echo "1. Visit https://huggingface.co/spaces"
- echo "2. Click 'Create new Space'"
- echo "3. Choose the 'Gradio' SDK"
- echo "4. Connect your GitHub repository"
- echo "5. Add HUGGINGFACEHUB_API_TOKEN in Settings"
- echo "6. Wait for the automatic deployment to finish"
- echo ""
- echo "📖 See HUGGINGFACE_DEPLOYMENT.md for details"
legacy/download_fix.py DELETED
@@ -1,11 +0,0 @@
- import os
- os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
- from huggingface_hub import snapshot_download
-
- print("🚀 Downloading model from hf-mirror...")
- snapshot_download(
-     repo_id="sentence-transformers/all-MiniLM-L6-v2",
-     ignore_patterns=["*.bin", "*.h5", "*.ot"],  # download only safetensors to save bandwidth
-     resume_download=True
- )
- print("✅ Download Complete!")
scripts/__init__.py ADDED
@@ -0,0 +1 @@
+ # Pipeline scripts package
scripts/data/__init__.py ADDED
@@ -0,0 +1 @@
+ # Data processing scripts
scripts/data/build_books_basic_info.py CHANGED
@@ -1,48 +1,48 @@
- import pandas as pd
  import csv

- # Read the raw data; skip malformed rows automatically so the pipeline does not break
- books_data = pd.read_csv(
-     "data/books_data.csv",
-     engine="python",
-     quotechar='"',
-     escapechar='\\',
-     on_bad_lines='skip'  # pandas >=1.3
- )
- ratings = pd.read_csv("data/Books_rating.csv", engine="python", quotechar='"', escapechar='\\', on_bad_lines='skip')
-
- # Keep only the useful columns
- books_cols = [
-     "Title", "description", "authors", "image", "publisher", "publishedDate", "categories"
- ]
- books_data = books_data[books_cols]
-
- # Keep only Title, Id, review/score for the merge
- ratings_cols = ["Title", "Id", "review/score"]
- ratings = ratings[ratings_cols]
-
- # Drop duplicates
- ratings = ratings.drop_duplicates(subset=["Title"])
-
- # Left join, keeping every row of books_data
- merged = books_data.merge(ratings, on="Title", how="left")
-
- # Rename columns
- merged = merged.rename(columns={
-     "Id": "isbn10",
-     "Title": "title",
-     "authors": "authors",
-     "description": "description",
-     "image": "image",
-     "publisher": "publisher",
-     "publishedDate": "publishedDate",
-     "categories": "categories",
-     "review/score": "average_rating"
- })
-
- # Generate isbn13 (placeholder; a more elaborate rule can be added later)
- merged["isbn13"] = None  # isbn13 generation logic can be filled in later
-
- # Save the new table; quote every field so description etc. are not truncated
- merged.to_csv("data/books_basic_info.csv", index=False, quoting=csv.QUOTE_ALL, quotechar='"', escapechar='\\')
- print("Generated data/books_basic_info.csv with the basic book info columns.")

  import csv
+ import logging
+ from pathlib import Path
+
+ import pandas as pd

+ logger = logging.getLogger(__name__)
+
+
+ def run(
+     books_path: Path = Path("data/books_data.csv"),
+     ratings_path: Path = Path("data/Books_rating.csv"),
+     output_path: Path = Path("data/books_basic_info.csv"),
+ ) -> None:
+     """Build books basic info from raw data. Callable from Pipeline."""
+     books_data = pd.read_csv(
+         str(books_path),
+         engine="python",
+         quotechar='"',
+         escapechar='\\',
+         on_bad_lines='skip',
+     )
+     ratings = pd.read_csv(
+         str(ratings_path),
+         engine="python",
+         quotechar='"',
+         escapechar='\\',
+         on_bad_lines='skip',
+     )
+
+     books_cols = ["Title", "description", "authors", "image", "publisher", "publishedDate", "categories"]
+     books_data = books_data[books_cols]
+     ratings = ratings[["Title", "Id", "review/score"]].drop_duplicates(subset=["Title"])
+     merged = books_data.merge(ratings, on="Title", how="left")
+     merged = merged.rename(columns={
+         "Id": "isbn10", "Title": "title", "authors": "authors", "description": "description",
+         "image": "image", "publisher": "publisher", "publishedDate": "publishedDate",
+         "categories": "categories", "review/score": "average_rating"
+     })
+     merged["isbn13"] = None
+
+     merged.to_csv(str(output_path), index=False, quoting=csv.QUOTE_ALL, quotechar='"', escapechar='\\')
+     logger.info("Saved %s", output_path)
+
+
+ if __name__ == "__main__":
+     logging.basicConfig(level=logging.INFO)
+     run()
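Because the script now exposes run(), a pipeline can import and call it directly instead of shelling out to the script (rule 3 in .cursorrules). A minimal sketch, assuming scripts/ is importable as a package; run_data_steps and its defaults are hypothetical:

from pathlib import Path

from scripts.data import build_books_basic_info, clean_data, generate_tags


def run_data_steps(data_dir: Path = Path("data")) -> None:
    """Chain the data-processing steps via their run() entry points (no subprocess)."""
    build_books_basic_info.run(output_path=data_dir / "books_basic_info.csv")
    clean_data.run(input_path=data_dir / "books_processed.csv", backup=True)
    generate_tags.run(input_path=data_dir / "books_processed.csv")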
 
 
 
scripts/data/clean_data.py CHANGED
@@ -226,23 +226,21 @@ def analyze_data_quality(df: pd.DataFrame, text_columns: list) -> dict:
226
  return stats
227
 
228
 
229
- def main():
230
- parser = argparse.ArgumentParser(description="Clean text data in books dataset")
231
- parser.add_argument("--input", type=Path, default=Path("data/books_processed.csv"))
232
- parser.add_argument("--output", type=Path, default=None)
233
- parser.add_argument("--dry-run", action="store_true", help="Analyze without saving")
234
- parser.add_argument("--backup", action="store_true", help="Create backup before overwriting")
235
- args = parser.parse_args()
236
-
237
- if args.output is None:
238
- args.output = args.input # Overwrite by default
239
-
240
- if not args.input.exists():
241
- raise FileNotFoundError(f"Input file not found: {args.input}")
242
 
243
- # Load data
244
- logger.info(f"Loading data from {args.input}")
245
- df = pd.read_csv(args.input)
246
  logger.info(f"Loaded {len(df):,} records")
247
 
248
  # Define columns to clean
@@ -261,31 +259,42 @@ def main():
261
  for col, s in stats_before.items():
262
  logger.info(f" {col}: {s['has_html']} HTML, {s['has_url']} URLs, avg_len={s['avg_length']:.0f}")
263
 
264
- if args.dry_run:
265
  logger.info("\n[DRY RUN] No changes will be saved")
266
  return
267
-
268
- # Clean
269
  logger.info("\n🧹 Cleaning data...")
270
  df = clean_dataframe(df, text_columns, max_lengths)
271
-
272
- # Analyze after
273
  logger.info("\n📊 Data quality AFTER cleaning:")
274
  stats_after = analyze_data_quality(df, text_columns)
275
  for col, s in stats_after.items():
276
  logger.info(f" {col}: {s['has_html']} HTML, {s['has_url']} URLs, avg_len={s['avg_length']:.0f}")
277
-
278
- # Backup if requested
279
- if args.backup and args.output.exists():
280
- backup_path = args.output.with_suffix('.csv.bak')
281
  logger.info(f"Creating backup: {backup_path}")
282
- args.output.rename(backup_path)
283
-
284
- # Save
285
- logger.info(f"\n💾 Saving to {args.output}")
286
- df.to_csv(args.output, index=False)
287
  logger.info("✅ Done!")
288
 
289
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  if __name__ == "__main__":
291
  main()
 
226
  return stats
227
 
228
 
229
+ def run(
230
+ backup: bool = False,
231
+ input_path: Optional[Path] = None,
232
+ output_path: Optional[Path] = None,
233
+ dry_run: bool = False,
234
+ ) -> None:
235
+ """Clean text data. Callable from Pipeline."""
236
+ input_path = input_path or Path("data/books_processed.csv")
237
+ output_path = output_path or input_path
238
+
239
+ if not input_path.exists():
240
+ raise FileNotFoundError(f"Input file not found: {input_path}")
 
241
 
242
+ logger.info(f"Loading data from {input_path}")
243
+ df = pd.read_csv(input_path)
 
244
  logger.info(f"Loaded {len(df):,} records")
245
 
246
  # Define columns to clean
 
259
  for col, s in stats_before.items():
260
  logger.info(f" {col}: {s['has_html']} HTML, {s['has_url']} URLs, avg_len={s['avg_length']:.0f}")
261
 
262
+ if dry_run:
263
  logger.info("\n[DRY RUN] No changes will be saved")
264
  return
265
+
 
266
  logger.info("\n🧹 Cleaning data...")
267
  df = clean_dataframe(df, text_columns, max_lengths)
268
+
 
269
  logger.info("\n📊 Data quality AFTER cleaning:")
270
  stats_after = analyze_data_quality(df, text_columns)
271
  for col, s in stats_after.items():
272
  logger.info(f" {col}: {s['has_html']} HTML, {s['has_url']} URLs, avg_len={s['avg_length']:.0f}")
273
+
274
+ if backup and output_path.exists():
275
+ backup_path = output_path.with_suffix('.csv.bak')
 
276
  logger.info(f"Creating backup: {backup_path}")
277
+ output_path.rename(backup_path)
278
+
279
+ logger.info(f"\n💾 Saving to {output_path}")
280
+ df.to_csv(output_path, index=False)
 
281
  logger.info("✅ Done!")
282
 
283
 
284
+ def main():
285
+ parser = argparse.ArgumentParser(description="Clean text data in books dataset")
286
+ parser.add_argument("--input", type=Path, default=Path("data/books_processed.csv"))
287
+ parser.add_argument("--output", type=Path, default=None)
288
+ parser.add_argument("--dry-run", action="store_true", help="Analyze without saving")
289
+ parser.add_argument("--backup", action="store_true", help="Create backup before overwriting")
290
+ args = parser.parse_args()
291
+ run(
292
+ backup=args.backup,
293
+ input_path=args.input,
294
+ output_path=args.output or args.input,
295
+ dry_run=args.dry_run,
296
+ )
297
+
298
+
299
  if __name__ == "__main__":
300
  main()
scripts/data/generate_emotions.py CHANGED
@@ -76,72 +76,32 @@ def scores_to_vector(scores: List[Dict[str, float]]) -> Dict[str, float]:
76
  return mapped
77
 
78
 
79
- def main():
80
- ap = argparse.ArgumentParser(description="Generate emotion scores from descriptions")
81
- ap.add_argument("--input", type=Path, default=Path("data/books_processed.csv"))
82
- ap.add_argument("--output", type=Path, default=Path("data/books_processed.csv"))
83
- ap.add_argument("--batch-size", type=int, default=16)
84
- ap.add_argument("--max-rows", type=int, default=None, help="Optional cap for debugging")
85
- ap.add_argument("--device", default=None, help="'mps' for Apple GPU, CUDA device id, or omit for CPU")
86
- ap.add_argument("--checkpoint", type=int, default=5000, help="Rows between checkpoint writes")
87
- ap.add_argument("--resume", action="store_true", help="Resume if output exists (skip rows with scores)")
88
- args = ap.parse_args()
89
-
90
- if not args.input.exists():
91
- raise FileNotFoundError(f"Input file not found: {args.input}")
92
-
93
- logger.info("Loading data from %s", args.input)
94
- df = pd.read_csv(args.input)
95
  if "description" not in df.columns:
96
  raise ValueError("Input CSV must have a 'description' column")
97
 
98
- if args.max_rows:
99
- df = df.head(args.max_rows)
100
- logger.info("Truncated to %d rows for max_rows", len(df))
101
-
102
- n = len(df)
103
- # Normalize device arg
104
- dev: str | int | None
105
- if args.device is None:
106
- dev = None
107
- else:
108
- if isinstance(args.device, str) and args.device.lower() == "mps":
109
- dev = "mps"
110
- else:
111
- try:
112
- dev = int(args.device)
113
- except ValueError:
114
- dev = None
115
- model = load_model(dev)
116
-
117
- # Prepare containers
118
  for col in TARGET_LABELS:
119
  if col not in df.columns:
120
  df[col] = 0.0
121
 
122
- # Resume support: if output exists, and resume flag set, load scores
123
- if args.resume and args.output.exists():
124
- logger.info("Resume enabled: loading existing output from %s", args.output)
125
- df_prev = pd.read_csv(args.output)
126
- for col in TARGET_LABELS:
127
- if col in df_prev.columns:
128
- df[col] = df_prev[col]
129
-
130
  texts = df["description"].fillna("").astype(str).tolist()
131
- batch = args.batch_size
132
- checkpoint = max(1, args.checkpoint)
133
-
134
- logger.info("Scoring %d descriptions (batch=%d, checkpoint=%d)...", n, batch, checkpoint)
135
- total_batches = (n + batch - 1) // batch
136
- for bidx, start in enumerate(tqdm(range(0, n, batch), total=total_batches)):
137
- end = min(start + batch, n)
138
-
139
- # Skip already-computed rows when resuming (all scores > 0)
140
- if args.resume:
141
- existing = df.loc[start:end-1, TARGET_LABELS].values
142
- if np.all(existing > 0):
143
- continue
144
 
 
 
 
145
  chunk = texts[start:end]
146
  outputs = model(chunk, truncation=True, max_length=512, top_k=None)
147
  for i, out in enumerate(outputs):
@@ -150,13 +110,29 @@ def main():
150
  for col in TARGET_LABELS:
151
  df.at[idx, col] = vec[col]
152
 
153
- # periodic checkpoint write
154
- if (start > 0) and ((start % checkpoint) == 0):
155
- df.to_csv(args.output, index=False)
156
 
157
- logger.info("Writing to %s", args.output)
158
- df.to_csv(args.output, index=False)
159
- logger.info("Done. Example row: %s", df.head(1)[TARGET_LABELS].to_dict(orient="records"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
 
162
  if __name__ == "__main__":
 
76
  return mapped
77
 
78
 
79
+ def run(
80
+ input_path: Path = Path("data/books_processed.csv"),
81
+ output_path: Path = Path("data/books_processed.csv"),
82
+ batch_size: int = 16,
83
+ device=None,
84
+ ) -> None:
85
+ """Generate emotion scores. Callable from Pipeline."""
86
+ if not input_path.exists():
87
+ raise FileNotFoundError(f"Input file not found: {input_path}")
88
+
89
+ logger.info("Loading data from %s", input_path)
90
+ df = pd.read_csv(input_path)
 
 
 
 
91
  if "description" not in df.columns:
92
  raise ValueError("Input CSV must have a 'description' column")
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  for col in TARGET_LABELS:
95
  if col not in df.columns:
96
  df[col] = 0.0
97
 
98
+ model = load_model(device)
 
 
 
 
 
 
 
99
  texts = df["description"].fillna("").astype(str).tolist()
100
+ n = len(df)
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
+ logger.info("Scoring %d descriptions...", n)
103
+ for start in tqdm(range(0, n, batch_size)):
104
+ end = min(start + batch_size, n)
105
  chunk = texts[start:end]
106
  outputs = model(chunk, truncation=True, max_length=512, top_k=None)
107
  for i, out in enumerate(outputs):
 
110
  for col in TARGET_LABELS:
111
  df.at[idx, col] = vec[col]
112
 
113
+ logger.info("Writing to %s", output_path)
114
+ df.to_csv(output_path, index=False)
115
+
116
 
117
+ def main():
118
+ ap = argparse.ArgumentParser(description="Generate emotion scores from descriptions")
119
+ ap.add_argument("--input", type=Path, default=Path("data/books_processed.csv"))
120
+ ap.add_argument("--output", type=Path, default=Path("data/books_processed.csv"))
121
+ ap.add_argument("--batch-size", type=int, default=16)
122
+ ap.add_argument("--max-rows", type=int, default=None, help="Optional cap for debugging")
123
+ ap.add_argument("--device", default=None, help="'mps' for Apple GPU, CUDA device id, or omit for CPU")
124
+ ap.add_argument("--checkpoint", type=int, default=5000, help="Rows between checkpoint writes")
125
+ ap.add_argument("--resume", action="store_true", help="Resume if output exists (skip rows with scores)")
126
+ args = ap.parse_args()
127
+ dev = None
128
+ if args.device:
129
+ dev = "mps" if str(args.device).lower() == "mps" else (int(args.device) if str(args.device).isdigit() else None)
130
+ run(
131
+ input_path=args.input,
132
+ output_path=args.output,
133
+ batch_size=args.batch_size,
134
+ device=dev,
135
+ )
136
 
137
 
138
  if __name__ == "__main__":
scripts/data/generate_tags.py CHANGED
@@ -116,6 +116,30 @@ def compute_tags(corpus: List[str], top_n: int, max_features: int, min_df: int,
116
  return tags
117
 
118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  def main():
120
  parser = argparse.ArgumentParser(description="Generate per-book tags from descriptions")
121
  parser.add_argument("--input", type=Path, default=Path("data/books_processed.csv"))
@@ -125,29 +149,15 @@ def main():
125
  parser.add_argument("--min-df", type=int, default=5)
126
  parser.add_argument("--max-df", type=float, default=0.5)
127
  args = parser.parse_args()
128
-
129
- if not args.input.exists():
130
- raise FileNotFoundError(f"Input file not found: {args.input}")
131
-
132
- logger.info("Loading data from %s", args.input)
133
- df = pd.read_csv(args.input)
134
- if "description" not in df.columns:
135
- raise ValueError("Input CSV must have a 'description' column")
136
-
137
- corpus = [normalize_text(x) for x in df["description"].fillna("").astype(str).tolist()]
138
- tags = compute_tags(
139
- corpus,
140
  top_n=args.top_n,
141
  max_features=args.max_features,
142
  min_df=args.min_df,
143
  max_df=args.max_df,
144
  )
145
 
146
- df["tags"] = tags
147
- logger.info("Writing tagged data to %s", args.output)
148
- df.to_csv(args.output, index=False)
149
- logger.info("Done. Sample tags: %s", tags[0:3])
150
-
151
 
152
  if __name__ == "__main__":
153
  main()
 
116
  return tags
117
 
118
 
119
+ def run(
120
+ input_path: Path = Path("data/books_processed.csv"),
121
+ output_path: Path = Path("data/books_processed.csv"),
122
+ top_n: int = 8,
123
+ max_features: int = 60000,
124
+ min_df: int = 5,
125
+ max_df: float = 0.5,
126
+ ) -> None:
127
+ """Generate per-book tags. Callable from Pipeline."""
128
+ if not input_path.exists():
129
+ raise FileNotFoundError(f"Input file not found: {input_path}")
130
+
131
+ logger.info("Loading data from %s", input_path)
132
+ df = pd.read_csv(input_path)
133
+ if "description" not in df.columns:
134
+ raise ValueError("Input CSV must have a 'description' column")
135
+
136
+ corpus = [normalize_text(x) for x in df["description"].fillna("").astype(str).tolist()]
137
+ tags = compute_tags(corpus, top_n=top_n, max_features=max_features, min_df=min_df, max_df=max_df)
138
+ df["tags"] = tags
139
+ logger.info("Writing tagged data to %s", output_path)
140
+ df.to_csv(output_path, index=False)
141
+
142
+
143
  def main():
144
  parser = argparse.ArgumentParser(description="Generate per-book tags from descriptions")
145
  parser.add_argument("--input", type=Path, default=Path("data/books_processed.csv"))
 
149
  parser.add_argument("--min-df", type=int, default=5)
150
  parser.add_argument("--max-df", type=float, default=0.5)
151
  args = parser.parse_args()
152
+ run(
153
+ input_path=args.input,
154
+ output_path=args.output,
 
 
 
 
 
 
 
 
 
155
  top_n=args.top_n,
156
  max_features=args.max_features,
157
  min_df=args.min_df,
158
  max_df=args.max_df,
159
  )
160
 
 
 
 
 
 
161
 
162
  if __name__ == "__main__":
163
  main()
scripts/data/split_rec_data.py CHANGED
@@ -4,7 +4,7 @@

  Split strategy: temporal split (Leave-Last-Out)
  - Each user's most recent rating → test
- - Each user's second-to-last rating → val
- All remaining ratings → train

  Only keep users with >= 3 ratings (enough history)
@@ -15,128 +15,71 @@ import numpy as np
  from pathlib import Path
  from tqdm import tqdm
  import time

- print('='*60)
- print('Recommendation system data split')
- print('='*60)
-
- start_time = time.time()
-
- # Path configuration
- DATA_PATH = Path('data/raw/Books_rating.csv')
- OUTPUT_DIR = Path('data/rec')
- OUTPUT_DIR.mkdir(exist_ok=True)
-
- # ==================== 1. Load data ====================
- print('\n[1/5] Loading raw review data...')
- # Raw columns: Id (ISBN), User_id (user), review/score, review/time, review/text
- df = pd.read_csv(DATA_PATH, usecols=['Id', 'User_id', 'review/score', 'review/time', 'review/text'])
- df.columns = ['isbn', 'user_id', 'rating', 'timestamp', 'review']
-
- print(f' Raw records: {len(df):,}')
- print(f' Users: {df["user_id"].nunique():,}')
- print(f' Books: {df["isbn"].nunique():,}')
-
- # ==================== 2. Data cleaning ====================
- print('\n[2/5] Cleaning data...')
-
- # Drop duplicate ratings (same user, same book)
- df = df.drop_duplicates(subset=['user_id', 'isbn'], keep='last')
- print(f' After de-duplication: {len(df):,}')
-
- # Drop missing values
- df = df.dropna(subset=['rating', 'timestamp'])
- print(f' After dropping missing values: {len(df):,}')
-
- # Filter low-quality ratings (optional: keep only rating > 0)
- df = df[df['rating'] > 0]
- print(f' After filtering low quality: {len(df):,}')
-
- # ==================== 3. User filtering ====================
- print('\n[3/5] Selecting active users...')
-
- # Count ratings per user
- user_counts = df.groupby('user_id').size()
- print(f' Rating count distribution:')
- print(f' 1 rating: {(user_counts == 1).sum():,}')
- print(f' 2 ratings: {(user_counts == 2).sum():,}')
- print(f' 3-5 ratings: {((user_counts >= 3) & (user_counts <= 5)).sum():,}')
- print(f' 5-10 ratings: {((user_counts > 5) & (user_counts <= 10)).sum():,}')
- print(f' 10+ ratings: {(user_counts > 10).sum():,}')
-
- # Keep only users with >= 3 ratings (needs 1 train + 1 val + 1 test)
- active_users = user_counts[user_counts >= 3].index
- df = df[df['user_id'].isin(active_users)]
- print(f' Active users (>=3 ratings): {len(active_users):,}')
- print(f' Records after filtering: {len(df):,}')
-
- # ==================== 4. Temporal split ====================
- print('\n[4/5] Temporal split (Leave-Last-Out)...')
-
- # Sort by user and time
- df = df.sort_values(['user_id', 'timestamp'])
-
- train_list = []
- val_list = []
- test_list = []
-
- for user_id, group in tqdm(df.groupby('user_id'), desc=' Splitting users'):
-     # Sort by time
-     group = group.sort_values('timestamp')
-     n = len(group)
-
-     # Last rating → test
-     test_list.append(group.iloc[-1])
-
-     # Second-to-last rating → val
-     val_list.append(group.iloc[-2])
-
-     # Remaining ratings → train
-     train_list.extend(group.iloc[:-2].to_dict('records'))
-
- # Convert to DataFrames
- train_df = pd.DataFrame(train_list)
- val_df = pd.DataFrame(val_list)
- test_df = pd.DataFrame(test_list)
-
- print(f' Train set: {len(train_df):,} ({len(train_df)/len(df)*100:.1f}%)')
- print(f' Validation set: {len(val_df):,} ({len(val_df)/len(df)*100:.1f}%)')
- print(f' Test set: {len(test_df):,} ({len(test_df)/len(df)*100:.1f}%)')
-
- # ==================== 5. Save data ====================
- print('\n[5/5] Saving data...')
-
- train_df.to_csv(OUTPUT_DIR / 'train.csv', index=False)
- val_df.to_csv(OUTPUT_DIR / 'val.csv', index=False)
- test_df.to_csv(OUTPUT_DIR / 'test.csv', index=False)
-
- # Save the user list (for later evaluation)
- active_users_df = pd.DataFrame({'user_id': active_users})
- active_users_df.to_csv(OUTPUT_DIR / 'active_users.csv', index=False)
-
- # Save statistics
- stats = {
-     'total_records': len(df),
-     'train_records': len(train_df),
-     'val_records': len(val_df),
-     'test_records': len(test_df),
-     'active_users': len(active_users),
-     'books': df['isbn'].nunique(),
- }
-
- with open(OUTPUT_DIR / 'stats.txt', 'w') as f:
-     for k, v in stats.items():
-         f.write(f'{k}: {v:,}\n')
-
- elapsed = time.time() - start_time
-
- print('\n' + '='*60)
- print('✅ Data split complete!')
- print('='*60)
- print(f'Output directory: {OUTPUT_DIR}')
- print(f' - train.csv: {len(train_df):,} records')
- print(f' - val.csv: {len(val_df):,} records')
- print(f' - test.csv: {len(test_df):,} records')
- print(f' - active_users.csv: {len(active_users):,} users')
- print(f'Execution time: {elapsed:.1f}s')
- print('='*60)
 

  Split strategy: temporal split (Leave-Last-Out)
  - Each user's most recent rating → test
+ - Each user's second-to-last rating → val
  - All remaining ratings → train

  Only keep users with >= 3 ratings (enough history)
 
15
  from pathlib import Path
16
  from tqdm import tqdm
17
  import time
18
+ import logging
19
 
20
+ logger = logging.getLogger(__name__)
21
+
22
+ DATA_PATH = Path("data/raw/Books_rating.csv")
23
+ OUTPUT_DIR = Path("data/rec")
24
+
25
+
26
+ def run(
27
+ data_path: Path = DATA_PATH,
28
+ output_dir: Path = OUTPUT_DIR,
29
+ ) -> None:
30
+ """Split train/val/test with Leave-Last-Out. Callable from Pipeline."""
31
+ output_dir.mkdir(parents=True, exist_ok=True)
32
+ start_time = time.time()
33
+
34
+ logger.info("Loading raw ratings...")
35
+ df = pd.read_csv(data_path, usecols=['Id', 'User_id', 'review/score', 'review/time', 'review/text'])
36
+ df.columns = ['isbn', 'user_id', 'rating', 'timestamp', 'review']
37
+ logger.info(f" Records: {len(df):,}, Users: {df['user_id'].nunique():,}, Items: {df['isbn'].nunique():,}")
38
+
39
+ logger.info("Cleaning data...")
40
+ df = df.drop_duplicates(subset=['user_id', 'isbn'], keep='last')
41
+ df = df.dropna(subset=['rating', 'timestamp'])
42
+ df = df[df['rating'] > 0]
43
+
44
+ logger.info("Filtering active users (>=3 interactions)...")
45
+ user_counts = df.groupby('user_id').size()
46
+ active_users = user_counts[user_counts >= 3].index
47
+ df = df[df['user_id'].isin(active_users)]
48
+ logger.info(f" Active users: {len(active_users):,}, Records: {len(df):,}")
49
+
50
+ logger.info("Splitting train/val/test (Leave-Last-Out)...")
51
+ df = df.sort_values(['user_id', 'timestamp'])
52
+
53
+ train_list = []
54
+ val_list = []
55
+ test_list = []
56
+
57
+ for user_id, group in tqdm(df.groupby('user_id'), desc=" Splitting"):
58
+ group = group.sort_values('timestamp')
59
+ test_list.append(group.iloc[-1])
60
+ val_list.append(group.iloc[-2])
61
+ train_list.extend(group.iloc[:-2].to_dict('records'))
62
+
63
+ train_df = pd.DataFrame(train_list)
64
+ val_df = pd.DataFrame(val_list)
65
+ test_df = pd.DataFrame(test_list)
66
+
67
+ logger.info(f" Train: {len(train_df):,}, Val: {len(val_df):,}, Test: {len(test_df):,}")
68
+
69
+ train_df.to_csv(output_dir / 'train.csv', index=False)
70
+ val_df.to_csv(output_dir / 'val.csv', index=False)
71
+ test_df.to_csv(output_dir / 'test.csv', index=False)
72
+ pd.DataFrame({'user_id': active_users}).to_csv(output_dir / 'active_users.csv', index=False)
73
+
74
+ with open(output_dir / 'stats.txt', 'w') as f:
75
+ for k, v in [('total_records', len(df)), ('train_records', len(train_df)),
76
+ ('val_records', len(val_df)), ('test_records', len(test_df)),
77
+ ('active_users', len(active_users)), ('books', df['isbn'].nunique())]:
78
+ f.write(f'{k}: {v:,}\n')
79
+
80
+ logger.info("Split complete in %.1fs", time.time() - start_time)
81
+
82
+
83
+ if __name__ == "__main__":
84
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
85
+ run()
 
 
scripts/model/build_recall_models.py CHANGED
@@ -1,76 +1,61 @@
1
  #!/usr/bin/env python3
2
  """
3
- Build Traditional Recall Models (ItemCF, UserCF, Swing, Popularity, Item2Vec)
4
 
5
- Trains collaborative filtering, embedding-based, and popularity recall models.
6
- These are CPU-friendly and provide strong baselines.
7
 
8
  Usage:
9
  python scripts/model/build_recall_models.py
10
 
11
- Input:
12
- - data/rec/train.csv
13
-
14
- Output:
15
- - data/model/recall/itemcf.pkl (~1.4 GB)
16
- - data/model/recall/usercf.pkl (~70 MB)
17
- - data/model/recall/swing.pkl
18
- - data/model/recall/popularity.pkl
19
- - data/model/recall/item2vec.pkl
20
-
21
- Algorithms:
22
- - ItemCF: Co-rating similarity with direction weight (forward=1.0, backward=0.7)
23
- - UserCF: User similarity (Jaccard + activity penalty)
24
- - Swing: User-pair overlap weighting for substitute relationships
25
- - Popularity: Rating count with time decay
26
- - Item2Vec: Word2Vec (Skip-gram) on user interaction sequences
27
  """
28
 
29
  import sys
30
- import os
31
- sys.path.append(os.getcwd())
 
 
32
 
33
  import pandas as pd
34
  import logging
 
35
  from src.recall.itemcf import ItemCF
36
  from src.recall.usercf import UserCF
37
  from src.recall.swing import Swing
38
  from src.recall.popularity import PopularityRecall
39
  from src.recall.item2vec import Item2Vec
40
 
41
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
42
  logger = logging.getLogger(__name__)
43
 
 
 
 
 
44
  def main():
45
- logger.info("Loading training data...")
46
- df = pd.read_csv('data/rec/train.csv')
47
-
48
- # 1. ItemCF (force retrain — direction weight updated)
49
  logger.info("--- Training ItemCF ---")
50
- itemcf = ItemCF()
51
- itemcf.fit(df)
52
-
53
- # 2. UserCF
54
  logger.info("--- Training UserCF ---")
55
- usercf = UserCF()
56
- usercf.fit(df)
57
-
58
- # 3. Swing
59
  logger.info("--- Training Swing ---")
60
- swing = Swing()
61
- swing.fit(df)
62
 
63
- # 4. Popularity
64
  logger.info("--- Training Popularity ---")
65
- pop = PopularityRecall()
66
- pop.fit(df)
67
-
68
- # 5. Item2Vec
69
  logger.info("--- Training Item2Vec ---")
70
- item2vec = Item2Vec()
71
- item2vec.fit(df)
 
72
 
73
- logger.info("Recall models built and saved successfully!")
74
 
75
  if __name__ == "__main__":
76
  main()
 
1
  #!/usr/bin/env python3
2
  """
3
+ Entry script: Build recall models (ItemCF, UserCF, Swing, Popularity, Item2Vec).
4
 
5
+ All training logic lives in src/recall/*.fit(). This script only loads data,
6
+ imports models, and calls fit().
7
 
8
  Usage:
9
  python scripts/model/build_recall_models.py
10
 
11
+ Input: data/rec/train.csv (columns: user_id, isbn, rating, timestamp)
12
+ Output: data/model/recall/*.pkl, data/recall_models.db (ItemCF)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  """
14
 
15
  import sys
16
+ from pathlib import Path
17
+
18
+ # Run from project root
19
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
20
 
21
  import pandas as pd
22
  import logging
23
+
24
  from src.recall.itemcf import ItemCF
25
  from src.recall.usercf import UserCF
26
  from src.recall.swing import Swing
27
  from src.recall.popularity import PopularityRecall
28
  from src.recall.item2vec import Item2Vec
29
 
30
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
31
  logger = logging.getLogger(__name__)
32
 
33
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
34
+ TRAIN_PATH = PROJECT_ROOT / "data" / "rec" / "train.csv"
35
+
36
+
37
  def main():
38
+ logger.info("Loading training data from %s...", TRAIN_PATH)
39
+ df = pd.read_csv(TRAIN_PATH)
40
+ logger.info("Loaded %d records.", len(df))
41
+
42
  logger.info("--- Training ItemCF ---")
43
+ ItemCF().fit(df)
44
+
 
 
45
  logger.info("--- Training UserCF ---")
46
+ UserCF().fit(df)
47
+
 
 
48
  logger.info("--- Training Swing ---")
49
+ Swing().fit(df)
 
50
 
 
51
  logger.info("--- Training Popularity ---")
52
+ PopularityRecall().fit(df)
53
+
 
 
54
  logger.info("--- Training Item2Vec ---")
55
+ Item2Vec().fit(df)
56
+
57
+ logger.info("Recall models built and saved successfully.")
58
 
 
59
 
60
  if __name__ == "__main__":
61
  main()
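For reference, the interaction schema listed under Input above can be exercised directly against the recall classes. A minimal sketch with toy data (placeholder user ids and ISBNs, not real records; the real train.csv is produced by scripts/data/split_rec_data.py):

# Minimal sketch of the expected train.csv schema (toy, illustrative data).
import pandas as pd
from src.recall.itemcf import ItemCF

toy = pd.DataFrame({
    "user_id":   ["u1", "u1", "u2", "u2", "u3"],
    "isbn":      ["b1", "b2", "b1", "b3", "b2"],
    "rating":    [5, 4, 3, 5, 4],
    "timestamp": [1, 2, 1, 2, 1],
})
# fit() builds the direction-weighted similarity matrix and persists it to SQLite.
ItemCF(data_dir="data/rec", save_dir="data/model/recall").fit(toy)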
scripts/model/train_sasrec.py CHANGED
@@ -1,204 +1,44 @@
1
  #!/usr/bin/env python3
2
  """
3
- Train SASRec Self-Attentive Sequential Recommendation Model
4
 
5
- A Transformer-based model for sequential recommendation.
6
- Predicts the next item based on user's historical interaction sequence.
7
 
8
  Usage:
9
  python scripts/model/train_sasrec.py
10
 
11
- Input:
12
- - data/rec/user_sequences.pkl
13
- - data/rec/item_map.pkl
14
-
15
- Output:
16
- - data/model/rec/sasrec_model.pth (model weights)
17
- - data/rec/user_seq_emb.pkl (user sequence embeddings)
18
-
19
- Architecture:
20
- - Self-Attention layers (Transformer encoder)
21
- - Positional embeddings
22
- - BCE loss with negative sampling
23
-
24
- Recommended:
25
- - GPU: 30 epochs, ~20 minutes
26
- - The user embeddings are used as features in LGBMRanker and as an independent recall channel
27
  """
28
 
29
  import sys
30
- import os
31
- sys.path.append(os.getcwd())
 
32
 
33
- import torch
34
- import torch.nn as nn
35
- import torch.optim as optim
36
- from torch.utils.data import Dataset, DataLoader
37
- import pickle
38
- import numpy as np
39
  import logging
40
- from tqdm import tqdm
41
- from pathlib import Path
42
- from src.model.sasrec import SASRec
43
 
44
- logging.basicConfig(level=logging.INFO)
 
 
45
  logger = logging.getLogger(__name__)
46
 
47
- class SeqDataset(Dataset):
48
- def __init__(self, seqs_dict, num_items, max_len=50):
49
- self.seqs = []
50
- self.num_items = num_items
51
-
52
- # Prepare (seq_in, target) pairs
53
- for u, s in seqs_dict.items():
54
- if len(s) < 2:
55
- continue
56
-
57
- # Pad
58
- seq_processed = [0] * max_len
59
- seq_len = min(len(s), max_len)
60
- seq_processed[-seq_len:] = s[-seq_len:]
61
-
62
- self.seqs.append(seq_processed)
63
-
64
- self.seqs = torch.LongTensor(self.seqs)
65
-
66
- def __len__(self):
67
- return len(self.seqs)
68
-
69
- def __getitem__(self, idx):
70
- seq = self.seqs[idx]
71
-
72
- # Determine pos/neg for training
73
- # Target: seq shifted right (positives)
74
- pos = np.zeros_like(seq)
75
- pos[:-1] = seq[1:]
76
-
77
- # Negatives: random sample
78
- neg = np.random.randint(1, self.num_items + 1, size=len(seq))
79
-
80
- return seq, torch.LongTensor(pos), torch.LongTensor(neg)
81
 
82
- def train_sasrec():
83
- max_len = 50
84
- hidden_dim = 64
85
- batch_size = 128
86
- epochs = 30 # Increased from 3
87
- lr = 1e-4 # Aligned with optimizer
88
-
89
- data_dir = Path('data/rec')
90
- logger.info("Loading sequences...")
91
- with open(data_dir / 'user_sequences.pkl', 'rb') as f:
92
- seqs_dict = pickle.load(f)
93
-
94
- with open(data_dir / 'item_map.pkl', 'rb') as f:
95
- item_map = pickle.load(f)
96
- num_items = len(item_map)
97
-
98
- dataset = SeqDataset(seqs_dict, num_items, max_len)
99
- dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
100
-
101
- # Check for MPS (Mac GPU) or CUDA
102
- if torch.cuda.is_available():
103
- device = torch.device('cuda')
104
- elif torch.backends.mps.is_available():
105
- device = torch.device('mps')
106
- else:
107
- device = torch.device('cpu')
108
-
109
- logger.info(f"Training on {device}")
110
-
111
- model = SASRec(num_items, max_len, hidden_dim).to(device)
112
- optimizer = optim.Adam(model.parameters(), lr=1e-4) # Reduced LR
113
-
114
- # BCE Loss for Pos/Neg
115
- criterion = nn.BCEWithLogitsLoss()
116
-
117
- model.train()
118
- for epoch in range(epochs):
119
- total_loss = 0
120
- pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
121
- for seq, pos, neg in pbar:
122
- seq = seq.to(device)
123
- pos = pos.to(device)
124
- neg = neg.to(device)
125
-
126
- # Forward pass to get seq embeddings: [B, L, H]
127
- seq_emb = model(seq) # [B, L, H]
128
-
129
- # Mask padding (0) in targets
130
- mask = (pos != 0)
131
-
132
- # Get Item Embeddings for Pos and Neg
133
- pos_emb = model.item_emb(pos) # [B, L, H]
134
- neg_emb = model.item_emb(neg) # [B, L, H]
135
-
136
- # Calculate logits
137
- pos_logits = (seq_emb * pos_emb).sum(dim=-1)
138
- neg_logits = (seq_emb * neg_emb).sum(dim=-1)
139
-
140
- pos_logits = pos_logits[mask]
141
- neg_logits = neg_logits[mask]
142
-
143
- pos_labels = torch.ones_like(pos_logits)
144
- neg_labels = torch.zeros_like(neg_logits)
145
-
146
- loss = criterion(pos_logits, pos_labels) + criterion(neg_logits, neg_labels)
147
-
148
- optimizer.zero_grad()
149
- loss.backward()
150
-
151
- # Clip Gradient
152
- torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
153
-
154
- optimizer.step()
155
-
156
- total_loss += loss.item()
157
- pbar.set_postfix({'loss': total_loss / (pbar.n + 1)})
158
-
159
- # Save Model
160
- torch.save(model.state_dict(), data_dir / '../model/rec/sasrec_model.pth')
161
-
162
- # Save User Embeddings (Last hidden state)
163
- logger.info("Extracting User Sequence Embeddings...")
164
- model.eval()
165
- user_emb_dict = {}
166
-
167
- # Create evaluation loader (no shuffle, keep user order)
168
- # We need to map back to user_ids, so iterate dict directly
169
-
170
- # Batch processing for inference
171
- all_users = list(seqs_dict.keys())
172
-
173
- with torch.no_grad():
174
- for i in tqdm(range(0, len(all_users), batch_size)):
175
- batch_users = all_users[i : i+batch_size]
176
- batch_seqs = []
177
- for u in batch_users:
178
- s = seqs_dict[u]
179
- # Same padding logic
180
- seq_processed = [0] * max_len
181
- seq_len = min(len(s), max_len)
182
- if seq_len > 0:
183
- seq_processed[-seq_len:] = s[-seq_len:]
184
- batch_seqs.append(seq_processed)
185
-
186
- input_tensor = torch.LongTensor(batch_seqs).to(device)
187
-
188
- # Initial forward
189
- # Note: During inference, we use the FULL sequence to predict the FUTURE (Test Item)
190
- # So input is the full available history
191
-
192
- output = model(input_tensor) # [B, L, H]
193
- last_state = output[:, -1, :].cpu().numpy() # [B, H]
194
-
195
- for j, u in enumerate(batch_users):
196
- user_emb_dict[u] = last_state[j]
197
-
198
- with open(data_dir / 'user_seq_emb.pkl', 'wb') as f:
199
- pickle.dump(user_emb_dict, f)
200
-
201
- logger.info("User Seq Embeddings saved.")
202
 
203
  if __name__ == "__main__":
204
- train_sasrec()
 
1
  #!/usr/bin/env python3
2
  """
3
+ Entry script: Train SASRec sequential recommendation model.
4
 
5
+ All training logic lives in SASRecRecall.fit(). This script loads data
6
+ and calls fit().
7
 
8
  Usage:
9
  python scripts/model/train_sasrec.py
10
 
11
+ Input: data/rec/train.csv
12
+ Output: data/model/rec/sasrec_model.pth
13
+ data/rec/user_seq_emb.pkl, item_map.pkl, user_sequences.pkl
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  """
15
 
16
  import sys
17
+ from pathlib import Path
18
+
19
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
20
 
21
+ import pandas as pd
 
 
 
 
 
22
  import logging
 
 
 
23
 
24
+ from src.recall.sasrec_recall import SASRecRecall
25
+
26
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
27
  logger = logging.getLogger(__name__)
28
 
29
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
30
+ TRAIN_PATH = PROJECT_ROOT / "data" / "rec" / "train.csv"
31
+
32
+
33
+ def main():
34
+ logger.info("Loading training data from %s...", TRAIN_PATH)
35
+ df = pd.read_csv(TRAIN_PATH)
36
+ logger.info("Loaded %d records.", len(df))
37
+
38
+ model = SASRecRecall()
39
+ model.fit(df)
40
+ logger.info("SASRec training complete.")
 
41
 
42
 
43
  if __name__ == "__main__":
44
+ main()
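The user_seq_emb.pkl output listed above is what the LGBMRanker consumes as sequence features. A minimal read-back sketch, assuming SASRecRecall.fit() keeps the {user_id: vector} pickle format the previous script produced:

# Sketch: load SASRec user embeddings (format assumed from the prior script).
import pickle
from pathlib import Path

with open(Path("data/rec") / "user_seq_emb.pkl", "rb") as f:
    user_seq_emb = pickle.load(f)

some_user = next(iter(user_seq_emb))
print(some_user, user_seq_emb[some_user].shape)  # e.g. a (hidden_dim,) vector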
scripts/model/train_youtube_dnn.py CHANGED
@@ -1,239 +1,45 @@
1
  #!/usr/bin/env python3
2
  """
3
- Train YoutubeDNN Two-Tower Model for Candidate Retrieval
4
 
5
- A deep learning recall model using separate user and item towers.
6
- Trained with in-batch negative sampling for efficient learning.
7
 
8
  Usage:
9
  python scripts/model/train_youtube_dnn.py
10
 
11
- Input:
12
- - data/rec/user_sequences.pkl
13
- - data/rec/item_map.pkl
14
- - data/books_processed.csv (for category features)
15
-
16
- Output:
17
- - data/model/recall/youtube_dnn.pt (model weights)
18
- - data/model/recall/youtube_dnn_meta.pkl (config + mappings)
19
 
20
- Architecture:
21
- - User Tower: Embedding(history) -> Mean Pooling -> MLP
22
- - Item Tower: Embedding(item) + Embedding(category) -> MLP
23
- - Training: Contrastive loss with in-batch negatives
24
 
25
- Recommended:
26
- - GPU: 10-50 epochs, ~30 minutes
27
- - CPU: 3-5 epochs for testing only
28
- """
29
 
30
- import numpy as np
31
  import pandas as pd
32
- import torch
33
- import torch.nn as nn
34
- import torch.optim as optim
35
- from torch.utils.data import Dataset, DataLoader
36
- import pickle
37
- from pathlib import Path
38
- from tqdm import tqdm
39
- import sys
40
- import os
41
 
42
- # Add src to path
43
- sys.path.append(os.path.abspath('.'))
44
- from src.recall.youtube_dnn import YoutubeDNN
45
 
46
- # Configuration
47
- BATCH_SIZE = 512
48
- EPOCHS = 10
49
- LR = 0.001
50
- EMBED_DIM = 64
51
- MAX_HISTORY = 20
52
- DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
53
- if torch.backends.mps.is_available():
54
- DEVICE = torch.device('mps')
55
 
56
- print(f"Using device: {DEVICE}")
57
 
58
- def load_data():
59
- print("Loading data...")
60
- # Load mappings
61
- with open('data/rec/item_map.pkl', 'rb') as f:
62
- item_map = pickle.load(f)
63
- isbn_to_id = item_map
64
- id_to_isbn = {v: k for k, v in item_map.items()}
65
-
66
- # Load sequences
67
- with open('data/rec/user_sequences.pkl', 'rb') as f:
68
- user_seqs = pickle.load(f)
69
-
70
- # Load book features for category mapping
71
- books_df = pd.read_csv('data/books_processed.csv', usecols=['isbn13', 'simple_categories'])
72
- books_df['isbn'] = books_df['isbn13'].astype(str)
73
-
74
- # Create category map
75
- # Categories are often strings like 'Fiction', 'Juvenile Fiction'. We take the first one.
76
- cate_map = {'<PAD>': 0, '<UNK>': 1}
77
- item_to_cate = {}
78
-
79
- print("Building category map...")
80
- for _, row in books_df.iterrows():
81
- isbn = row['isbn']
82
- if isbn in isbn_to_id:
83
- iid = isbn_to_id[isbn]
84
- cates = str(row['simple_categories']).split(';')
85
- main_cate = cates[0].strip() if cates else 'Unknown'
86
-
87
- if main_cate not in cate_map:
88
- cate_map[main_cate] = len(cate_map)
89
-
90
- item_to_cate[iid] = cate_map[main_cate]
91
-
92
- # Default category for unknown items
93
- default_cate = cate_map.get('Unknown', 1)
94
-
95
- return user_seqs, item_to_cate, len(item_map)+1, len(cate_map), default_cate
96
 
97
- class RetrievalDataset(Dataset):
98
- def __init__(self, user_seqs, item_to_cate, default_cate, max_history=20):
99
- self.samples = []
100
- self.item_to_cate = item_to_cate
101
- self.default_cate = default_cate
102
- self.max_history = max_history
103
-
104
- print("Generating training samples...")
105
- # Leave-Last-Out:
106
- # Last item -> Test
107
- # 2nd Last -> Val
108
- # Rest -> Train
109
- # So we use items 0 to N-3 for training history generation
110
-
111
- for user, seq in tqdm(user_seqs.items()):
112
- if len(seq) < 3:
113
- continue
114
-
115
- # Use data up to the split point for training
116
- # Valid Train Set: seq[:-2]
117
- train_seq = seq[:-2]
118
-
119
- # Generate sliding window samples
120
- # minimum history length = 1
121
- for i in range(1, len(train_seq)):
122
- target = train_seq[i]
123
- history = train_seq[:i]
124
-
125
- # Truncate history
126
- if len(history) > max_history:
127
- history = history[-max_history:]
128
-
129
- self.samples.append((history, target))
130
-
131
- print(f"Total training samples: {len(self.samples)}")
132
-
133
- def __len__(self):
134
- return len(self.samples)
135
-
136
- def __getitem__(self, idx):
137
- history, target = self.samples[idx]
138
-
139
- # Padding history
140
- padded_hist = np.zeros(self.max_history, dtype=np.int64)
141
- length = min(len(history), self.max_history)
142
- if length > 0:
143
- padded_hist[:length] = history[-length:]
144
-
145
- target_cate = self.item_to_cate.get(target, self.default_cate)
146
-
147
- return torch.LongTensor(padded_hist), torch.tensor(target, dtype=torch.long), torch.tensor(target_cate, dtype=torch.long)
148
 
149
- def train():
150
- user_seqs, item_to_cate, vocab_size, cate_vocab_size, default_cate = load_data()
151
-
152
- dataset = RetrievalDataset(user_seqs, item_to_cate, default_cate, MAX_HISTORY)
153
- dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0) # mp issue on mac sometimes
154
-
155
- # Model Setup
156
- user_config = {
157
- 'vocab_size': vocab_size,
158
- 'embed_dim': EMBED_DIM,
159
- 'history_len': MAX_HISTORY
160
- }
161
- item_config = {
162
- 'vocab_size': vocab_size,
163
- 'embed_dim': EMBED_DIM,
164
- 'cate_vocab_size': cate_vocab_size,
165
- 'cate_embed_dim': 32
166
- }
167
- model_config = {
168
- 'hidden_dims': [128, 64],
169
- 'dropout': 0.1
170
- }
171
-
172
- model = YoutubeDNN(user_config, item_config, model_config).to(DEVICE)
173
- optimizer = optim.Adam(model.parameters(), lr=LR)
174
- criterion = nn.CrossEntropyLoss() # For In-Batch Negatives
175
-
176
- print("Start Training...")
177
- model.train()
178
-
179
- for epoch in range(EPOCHS):
180
- total_loss = 0
181
- steps = 0
182
-
183
- pbar = tqdm(dataloader)
184
- for history, target_item, target_cate in pbar:
185
- history = history.to(DEVICE)
186
- target_item = target_item.to(DEVICE)
187
- target_cate = target_cate.to(DEVICE)
188
-
189
- optimizer.zero_grad()
190
-
191
- # Get Vectors
192
- user_vec = model.user_tower(history) # (B, D)
193
- item_vec = model.item_tower(target_item, target_cate) # (B, D)
194
-
195
- # Normalize
196
- user_vec = nn.functional.normalize(user_vec, p=2, dim=1)
197
- item_vec = nn.functional.normalize(item_vec, p=2, dim=1)
198
-
199
- # In-Batch Negatives
200
- # logits[i][j] = user_i dot item_j
201
- # We want logits[i][i] to be high
202
- logits = torch.matmul(user_vec, item_vec.t()) # (B, B)
203
-
204
- # Temperature scaling (optional, helps convergence)
205
- logits = logits / 0.1
206
-
207
- labels = torch.arange(len(user_vec)).to(DEVICE)
208
-
209
- loss = criterion(logits, labels)
210
- loss.backward()
211
- optimizer.step()
212
-
213
- total_loss += loss.item()
214
- steps += 1
215
-
216
- pbar.set_description(f"Epoch {epoch+1} Loss: {total_loss/steps:.4f}")
217
-
218
- print(f"Epoch {epoch+1} finished. Avg Loss: {total_loss/steps:.4f}")
219
-
220
- # Save Model
221
- save_path = Path('data/model/recall')
222
- save_path.mkdir(parents=True, exist_ok=True)
223
-
224
- torch.save(model.state_dict(), save_path / 'youtube_dnn.pt')
225
-
226
- # Save metadata
227
- meta = {
228
- 'user_config': user_config,
229
- 'item_config': item_config,
230
- 'model_config': model_config,
231
- 'item_to_cate': item_to_cate
232
- }
233
- with open(save_path / 'youtube_dnn_meta.pkl', 'wb') as f:
234
- pickle.dump(meta, f)
235
-
236
- print(f"Model saved to {save_path}")
237
 
238
- if __name__ == '__main__':
239
- train()
 
1
  #!/usr/bin/env python3
2
  """
3
+ Entry script: Train YoutubeDNN Two-Tower recall model.
4
 
5
+ All training logic lives in YoutubeDNNRecall.fit(). This script loads data
6
+ and calls fit().
7
 
8
  Usage:
9
  python scripts/model/train_youtube_dnn.py
10
 
11
+ Input: data/rec/train.csv
12
+ Output: data/model/recall/youtube_dnn.pt, youtube_dnn_meta.pkl
13
+ data/rec/item_map.pkl, user_sequences.pkl
14
+ """
 
 
 
 
15
 
16
+ import sys
17
+ from pathlib import Path
 
 
18
 
19
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
 
 
 
20
 
 
21
  import pandas as pd
22
+ import logging
23
+
24
+ from src.recall.embedding import YoutubeDNNRecall
 
 
 
 
 
 
25
 
26
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
27
+ logger = logging.getLogger(__name__)
 
28
 
29
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
30
+ TRAIN_PATH = PROJECT_ROOT / "data" / "rec" / "train.csv"
31
+ BOOKS_PATH = PROJECT_ROOT / "data" / "books_processed.csv"
 
 
 
 
 
 
32
 
 
33
 
34
+ def main():
35
+ logger.info("Loading training data from %s...", TRAIN_PATH)
36
+ df = pd.read_csv(TRAIN_PATH)
37
+ logger.info("Loaded %d records.", len(df))
 
38
 
39
+ model = YoutubeDNNRecall()
40
+ model.fit(df, books_path=BOOKS_PATH)
41
+ logger.info("YoutubeDNN training complete.")
 
42
 
43
 
44
+ if __name__ == "__main__":
45
+ main()
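Because YoutubeDNNRecall.fit() exposes the training hyperparameters (see src/recall/embedding.py later in this commit), a quick CPU smoke test can override them without editing this script. A minimal sketch; the reduced epochs/batch size are illustrative values, not tuned settings:

# Sketch: quick CPU smoke test with reduced epochs.
import pandas as pd
from src.recall.embedding import YoutubeDNNRecall

df = pd.read_csv("data/rec/train.csv")
YoutubeDNNRecall().fit(df, books_path="data/books_processed.csv",
                       epochs=2, batch_size=256)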
scripts/run_pipeline.py CHANGED
@@ -2,8 +2,8 @@
2
  """
3
  Unified Data Pipeline Runner
4
 
5
- Executes the complete data processing pipeline in correct order.
6
- Supports partial runs and validation between stages.
7
 
8
  Usage:
9
  python scripts/run_pipeline.py # Full pipeline
@@ -13,156 +13,208 @@ Usage:
13
  """
14
 
15
  import argparse
16
- import subprocess
17
  import sys
18
  import time
19
  from pathlib import Path
20
 
21
- PROJECT_ROOT = Path(__file__).parent.parent
22
-
23
-
24
- def run_script(script_path: str, description: str, args: list = None):
25
- """Run a Python script and handle errors."""
26
- print(f"\n{'='*60}")
27
- print(f"▶️ {description}")
28
- print(f" Script: {script_path}")
29
- print("=" * 60)
30
-
31
- cmd = [sys.executable, script_path]
32
- if args:
33
- cmd.extend(args)
34
-
35
- start = time.time()
36
- result = subprocess.run(cmd, cwd=PROJECT_ROOT)
37
- elapsed = time.time() - start
38
-
39
- if result.returncode != 0:
40
- print(f"\n❌ FAILED: {description} (exit code: {result.returncode})")
41
- sys.exit(1)
42
-
43
- print(f"✅ Completed in {elapsed:.1f}s")
44
- return True
 
45
 
46
 
47
  def main():
48
- parser = argparse.ArgumentParser(description="Run data pipeline")
49
- parser.add_argument("--stage", choices=[
50
- "all", "books", "rec", "index", "models"
51
- ], default="all", help="Which stage to run")
 
 
 
52
  parser.add_argument("--skip-models", action="store_true", help="Skip model training")
53
  parser.add_argument("--skip-index", action="store_true", help="Skip index building")
54
  parser.add_argument("--validate-only", action="store_true", help="Only run validation")
55
- parser.add_argument("--device", default=None, help="Device for ML models (cpu/cuda/mps)")
56
- parser.add_argument("--stacking", action="store_true", help="Enable Stacking model training (LGBM + XGB + Meta)")
57
  args = parser.parse_args()
58
-
59
- print("=" * 60)
60
- print("🚀 DATA PIPELINE RUNNER")
61
- print("=" * 60)
62
-
63
  if args.validate_only:
64
- run_script("scripts/data/validate_data.py", "Validating all data")
 
65
  return
66
-
67
- start_total = time.time()
68
-
69
- # ==========================================================================
70
- # Stage 1: Book Data Processing
71
- # ==========================================================================
72
- if args.stage in ["all", "books"]:
73
- run_script(
74
- "scripts/data/clean_data.py",
75
- "Cleaning text data (HTML, encoding, whitespace)",
76
- args=["--backup"]
77
- )
78
-
79
- run_script(
80
- "scripts/data/build_books_basic_info.py",
81
- "Building books basic info"
82
- )
83
-
84
- device_args = ["--device", args.device] if args.device else []
85
- run_script(
86
- "scripts/data/generate_emotions.py",
87
- "Generating emotion scores",
88
- args=device_args
89
- )
90
-
91
- run_script(
92
- "scripts/data/generate_tags.py",
93
- "Generating tags"
94
- )
95
-
96
- run_script(
97
- "scripts/data/chunk_reviews.py",
98
- "Chunking reviews for Small-to-Big"
99
- )
100
-
101
- # ==========================================================================
102
- # Stage 2: RecSys Data Preparation
103
- # ==========================================================================
104
- if args.stage in ["all", "rec"]:
105
- run_script(
106
- "scripts/data/split_rec_data.py",
107
- "Splitting train/val/test data"
108
- )
109
-
110
- run_script(
111
- "scripts/data/build_sequences.py",
112
- "Building user sequences"
113
- )
114
-
115
- # ==========================================================================
116
- # Stage 3: Index Building
117
- # ==========================================================================
118
- if args.stage in ["all", "index"] and not args.skip_index:
119
- run_script(
120
- "scripts/init_sqlite_db.py",
121
- "Building SQLite metadata (books.db)"
122
- )
123
- run_script(
124
- "scripts/data/init_dual_index.py",
125
- "Building chunk vector index"
126
- )
127
-
128
- # ==========================================================================
129
- # Stage 4: Model Training
130
- # ==========================================================================
131
- if args.stage in ["all", "models"] and not args.skip_models:
132
- run_script(
133
- "scripts/model/build_recall_models.py",
134
- "Building ItemCF/UserCF models"
135
- )
136
-
137
- run_script(
138
- "scripts/model/train_youtube_dnn.py",
139
- "Training YoutubeDNN (requires GPU)"
140
- )
141
-
142
- run_script(
143
- "scripts/model/train_sasrec.py",
144
- "Training SASRec (requires GPU)"
145
- )
146
-
147
- ranker_args = ["--stacking"] if args.stacking else []
148
- run_script(
149
- "scripts/model/train_ranker.py",
150
- "Training LGBMRanker (Stacking: {})".format("ON" if args.stacking else "OFF"),
151
- args=ranker_args
152
- )
153
-
154
- # ==========================================================================
155
- # Final Validation
156
- # ==========================================================================
157
- run_script(
158
- "scripts/data/validate_data.py",
159
- "Final validation"
160
  )
161
-
162
- elapsed_total = time.time() - start_total
163
- print("\n" + "=" * 60)
164
- print(f"🎉 PIPELINE COMPLETED in {elapsed_total/60:.1f} minutes")
165
- print("=" * 60)
166
 
167
 
168
  if __name__ == "__main__":
 
2
  """
3
  Unified Data Pipeline Runner
4
 
5
+ Orchestrates Data Cleaning -> Training -> Evaluation using direct Python imports.
6
+ No subprocess calls. All logic invoked via Module.run() or src classes.
7
 
8
  Usage:
9
  python scripts/run_pipeline.py # Full pipeline
 
13
  """
14
 
15
  import argparse
16
+ import logging
17
  import sys
18
  import time
19
  from pathlib import Path
20
 
21
+ # Ensure project root is on path
22
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
23
+ sys.path.insert(0, str(PROJECT_ROOT))
24
+
25
+ logging.basicConfig(
26
+ level=logging.INFO,
27
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
28
+ datefmt="%H:%M:%S",
29
+ )
30
+ logger = logging.getLogger("pipeline")
31
+
32
+
33
+ class Pipeline:
34
+ """
35
+ Manages the full data pipeline: Data Cleaning -> Training -> Evaluation.
36
+
37
+ All stages use direct Python imports; no subprocess.
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ project_root: Path = PROJECT_ROOT,
43
+ device: str | None = None,
44
+ skip_models: bool = False,
45
+ skip_index: bool = False,
46
+ stacking: bool = False,
47
+ ):
48
+ self.project_root = Path(project_root)
49
+ self.data_dir = self.project_root / "data"
50
+ self.rec_dir = self.data_dir / "rec"
51
+ self.model_dir = self.data_dir / "model"
52
+ self.device = device
53
+ self.skip_models = skip_models
54
+ self.skip_index = skip_index
55
+ self.stacking = stacking
56
+
57
+ def _run_step(self, name: str, fn, *args, **kwargs):
58
+ """Run a step with timing log."""
59
+ logger.info("▶ %s", name)
60
+ start = time.time()
61
+ fn(*args, **kwargs)
62
+ logger.info(" ✓ Done in %.1fs", time.time() - start)
63
+
64
+ def run_data_cleaning(self, stage: str = "all") -> None:
65
+ """Stage 1: Book data processing."""
66
+ if stage not in ("all", "books"):
67
+ return
68
+
69
+ from scripts.data.clean_data import run as clean_run
70
+ self._run_step("Clean text data", clean_run, backup=True)
71
+
72
+ from scripts.data.build_books_basic_info import run as build_run
73
+ raw_dir = self.data_dir / "raw"
74
+ self._run_step("Build books basic info", build_run,
75
+ books_path=raw_dir / "books_data.csv",
76
+ ratings_path=raw_dir / "Books_rating.csv",
77
+ output_path=self.data_dir / "books_basic_info.csv",
78
+ )
79
+
80
+ from scripts.data.generate_emotions import run as emotions_run
81
+ self._run_step("Generate emotion scores", emotions_run, device=self.device)
82
+
83
+ from scripts.data.generate_tags import run as tags_run
84
+ self._run_step("Generate tags", tags_run)
85
+
86
+ from scripts.data.chunk_reviews import chunk_reviews
87
+ self._run_step("Chunk reviews", chunk_reviews,
88
+ str(self.data_dir / "review_highlights.txt"),
89
+ str(self.data_dir / "review_chunks.jsonl"),
90
+ )
91
+
92
+ def run_rec_preparation(self, stage: str = "all") -> None:
93
+ """Stage 2: RecSys data preparation."""
94
+ if stage not in ("all", "rec"):
95
+ return
96
+
97
+ from scripts.data.split_rec_data import run as split_run
98
+ self._run_step("Split train/val/test", split_run,
99
+ data_path=self.data_dir / "raw" / "Books_rating.csv",
100
+ output_dir=self.rec_dir,
101
+ )
102
+
103
+ from scripts.data.build_sequences import build_sequences
104
+ self._run_step("Build user sequences", build_sequences, str(self.rec_dir))
105
+
106
+ def run_index_building(self, stage: str = "all") -> None:
107
+ """Stage 3: Index building."""
108
+ if stage not in ("all", "index") or self.skip_index:
109
+ return
110
+
111
+ from scripts.init_sqlite_db import init_sqlite_db
112
+ self._run_step("Build SQLite metadata (books.db)", init_sqlite_db, str(self.data_dir))
113
+
114
+ from scripts.data.init_dual_index import init_chunk_index
115
+ self._run_step("Build chunk vector index", init_chunk_index)
116
+
117
+ def run_training(self, stage: str = "all") -> None:
118
+ """Stage 4: Model training via src imports."""
119
+ if stage not in ("all", "models") or self.skip_models:
120
+ return
121
+
122
+ train_path = self.rec_dir / "train.csv"
123
+ if not train_path.exists():
124
+ logger.warning("train.csv not found, skipping model training")
125
+ return
126
+
127
+ import pandas as pd
128
+ df = pd.read_csv(train_path)
129
+
130
+ from src.recall.itemcf import ItemCF
131
+ self._run_step("Train ItemCF", lambda: ItemCF().fit(df))
132
+
133
+ from src.recall.usercf import UserCF
134
+ self._run_step("Train UserCF", lambda: UserCF().fit(df))
135
+
136
+ from src.recall.swing import Swing
137
+ self._run_step("Train Swing", lambda: Swing().fit(df))
138
+
139
+ from src.recall.popularity import PopularityRecall
140
+ self._run_step("Train Popularity", lambda: PopularityRecall().fit(df))
141
+
142
+ from src.recall.item2vec import Item2Vec
143
+ self._run_step("Train Item2Vec", lambda: Item2Vec().fit(df))
144
+
145
+ from src.recall.embedding import YoutubeDNNRecall
146
+ self._run_step("Train YoutubeDNN", lambda: YoutubeDNNRecall().fit(
147
+ df, books_path=self.data_dir / "books_processed.csv"
148
+ ))
149
+
150
+ from src.recall.sasrec_recall import SASRecRecall
151
+ self._run_step("Train SASRec", lambda: SASRecRecall().fit(df))
152
+
153
+ from scripts.model.train_ranker import train_ranker, train_stacking
154
+ self._run_step("Train Ranker", train_stacking if self.stacking else train_ranker)
155
+
156
+ def run_evaluation(self) -> None:
157
+ """Stage 5: Validation."""
158
+ def _validate():
159
+ from scripts.data.validate_data import (
160
+ validate_raw, validate_processed, validate_rec,
161
+ validate_index, validate_models,
162
+ )
163
+ validate_raw()
164
+ validate_processed()
165
+ validate_rec()
166
+ validate_index()
167
+ validate_models()
168
+
169
+ self._run_step("Validate pipeline", _validate)
170
+
171
+ def run(self, stage: str = "all") -> None:
172
+ """Execute full pipeline: Data Cleaning -> Training -> Evaluation."""
173
+ logger.info("=" * 60)
174
+ logger.info("Pipeline: Data Cleaning -> Training -> Evaluation")
175
+ logger.info("=" * 60)
176
+
177
+ start_total = time.time()
178
+
179
+ self.run_data_cleaning(stage)
180
+ self.run_rec_preparation(stage)
181
+ self.run_index_building(stage)
182
+ self.run_training(stage)
183
+ self.run_evaluation()
184
+
185
+ elapsed = time.time() - start_total
186
+ logger.info("=" * 60)
187
+ logger.info("Pipeline completed in %.1f min", elapsed / 60)
188
+ logger.info("=" * 60)
189
 
190
 
191
  def main():
192
+ parser = argparse.ArgumentParser(description="Run data pipeline (no subprocess)")
193
+ parser.add_argument(
194
+ "--stage",
195
+ choices=["all", "books", "rec", "index", "models"],
196
+ default="all",
197
+ help="Which stage to run",
198
+ )
199
  parser.add_argument("--skip-models", action="store_true", help="Skip model training")
200
  parser.add_argument("--skip-index", action="store_true", help="Skip index building")
201
  parser.add_argument("--validate-only", action="store_true", help="Only run validation")
202
+ parser.add_argument("--device", default=None, help="Device for ML (cpu/cuda/mps)")
203
+ parser.add_argument("--stacking", action="store_true", help="Enable stacking ranker")
204
  args = parser.parse_args()
205
+
 
 
 
 
206
  if args.validate_only:
207
+ logger.info("Validation only")
208
+ Pipeline().run_evaluation()
209
  return
210
+
211
+ pipeline = Pipeline(
212
+ device=args.device,
213
+ skip_models=args.skip_models,
214
+ skip_index=args.skip_index,
215
+ stacking=args.stacking,
 
216
  )
217
+ pipeline.run(stage=args.stage)
 
 
 
 
218
 
219
 
220
  if __name__ == "__main__":
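Since the runner is now a class rather than a chain of subprocess calls, individual stages can also be invoked programmatically (e.g. from a notebook or a test). A small sketch using only the methods shown above:

# Sketch: programmatic partial run, roughly equivalent to
#   python scripts/run_pipeline.py --stage rec --skip-models --skip-index
from scripts.run_pipeline import Pipeline

pipeline = Pipeline(skip_models=True, skip_index=True)
pipeline.run_rec_preparation(stage="rec")  # split train/val/test + build sequences
pipeline.run_evaluation()                  # run the validation checks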
src/core/model_loader.py CHANGED
@@ -13,7 +13,7 @@ Usage:
13
  import os
14
  import logging
15
  from pathlib import Path
16
- from huggingface_hub import hf_hub_download, snapshot_download
17
  from src.config import DATA_DIR
18
 
19
  logger = logging.getLogger(__name__)
 
13
  import os
14
  import logging
15
  from pathlib import Path
16
+ from huggingface_hub import snapshot_download
17
  from src.config import DATA_DIR
18
 
19
  logger = logging.getLogger(__name__)
src/data/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Data access layer for book recommendation system."""
2
+
3
+ from src.data.repository import DataRepository, data_repository
4
+
5
+ __all__ = ["DataRepository", "data_repository"]
src/data/repository.py ADDED
@@ -0,0 +1,94 @@
 
1
+ """
2
+ Unified Data Repository for book recommendation system.
3
+
4
+ Centralizes all core data access: books metadata, user history, etc.
5
+ Replaces scattered pandas.read_csv and pickle.load calls across services.
6
+ """
7
+
8
+ from pathlib import Path
9
+ from typing import Any, Dict, List, Optional
10
+ import sqlite3
11
+
12
+ from src.config import DATA_DIR
13
+ from src.core.metadata_store import metadata_store
14
+ from src.utils import setup_logger
15
+
16
+ logger = setup_logger(__name__)
17
+
18
+ # Core data file paths
19
+ BOOKS_DB_PATH = DATA_DIR / "books.db"
20
+ BOOKS_PROCESSED_CSV = DATA_DIR / "books_processed.csv"
21
+ RECALL_MODELS_DB = DATA_DIR / "recall_models.db"
22
+
23
+
24
+ class DataRepository:
25
+ """
26
+ Singleton data access layer. Manages loading of books_processed.csv,
27
+ books.db, recall_models.db (user_history), etc.
28
+ """
29
+
30
+ _instance: Optional["DataRepository"] = None
31
+
32
+ def __new__(cls) -> "DataRepository":
33
+ if cls._instance is None:
34
+ cls._instance = super(DataRepository, cls).__new__(cls)
35
+ cls._instance._initialized = False
36
+ return cls._instance
37
+
38
+ def __init__(self) -> None:
39
+ if getattr(self, "_initialized", False):
40
+ return
41
+ self._initialized = True
42
+ self._recall_conn: Optional[sqlite3.Connection] = None
43
+ logger.info("DataRepository: Initialized (singleton)")
44
+
45
+ def _get_recall_connection(self) -> Optional[sqlite3.Connection]:
46
+ """Lazy SQLite connection for recall_models.db."""
47
+ if self._recall_conn is None:
48
+ if not RECALL_MODELS_DB.exists():
49
+ logger.warning(f"recall_models.db not found at {RECALL_MODELS_DB}")
50
+ return None
51
+ try:
52
+ self._recall_conn = sqlite3.connect(
53
+ str(RECALL_MODELS_DB), check_same_thread=False
54
+ )
55
+ except sqlite3.Error as e:
56
+ logger.error(f"DataRepository: Failed to connect to recall DB: {e}")
57
+ return self._recall_conn
58
+
59
+ def get_book_metadata(self, isbn: str) -> Optional[Dict[str, Any]]:
60
+ """
61
+ Get book metadata by ISBN.
62
+
63
+ Uses MetadataStore (books.db) as primary source. Returns None if not found.
64
+ """
65
+ meta = metadata_store.get_book_metadata(str(isbn))
66
+ return meta if meta else None
67
+
68
+ def get_user_history(self, user_id: str) -> List[str]:
69
+ """
70
+ Get user's interaction history (ISBNs) from recall_models.db.
71
+
72
+ Used by recommendation algorithms (ItemCF, etc.). Returns empty list if
73
+ DB unavailable or user has no history.
74
+ """
75
+ conn = self._get_recall_connection()
76
+ if not conn:
77
+ return []
78
+ try:
79
+ cursor = conn.cursor()
80
+ cursor.execute(
81
+ "SELECT isbn FROM user_history WHERE user_id = ?", (user_id,)
82
+ )
83
+ return [row[0] for row in cursor.fetchall()]
84
+ except sqlite3.Error as e:
85
+ logger.error(f"DataRepository: get_user_history failed: {e}")
86
+ return []
87
+
88
+ def get_all_categories(self) -> List[str]:
89
+ """Get unique book categories. Delegates to MetadataStore."""
90
+ return metadata_store.get_all_categories()
91
+
92
+
93
+ # Global singleton instance
94
+ data_repository = DataRepository()
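Call sites resolve data through this shared singleton instead of ad-hoc pandas/pickle loads. A minimal usage sketch (the user id is a placeholder):

# Sketch: services go through the repository singleton.
from src.data import data_repository

history = data_repository.get_user_history("some_user_id")  # placeholder id
meta = data_repository.get_book_metadata(history[0]) if history else None
categories = data_repository.get_all_categories()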
src/init_db.py CHANGED
@@ -1,7 +1,6 @@
1
  import os
2
  import shutil
3
  import sys
4
- import torch
5
  from pathlib import Path
6
 
7
  # Add project root to Python path
@@ -21,20 +20,10 @@ def init_db():
21
  # FIX: Disable Tokenizers Parallelism to prevent deadlocks on macOS
22
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
23
 
24
- # Force CPU for data ingestion to avoid MPS (Metal) async hangs during long processing
25
- # We only need speed for inference, reliability is key for building the DB.
26
  device = "cpu"
27
  print("🐢 Forcing CPU for stable database ingestion (prevents macOS Freezes).")
28
-
29
- # if torch.backends.mps.is_available():
30
- # device = "mps"
31
- # print("⚡️ MacOS GPU (MPS) Detected! switching to GPU acceleration.")
32
- # elif torch.cuda.is_available():
33
- # device = "cuda"
34
- # print("⚡️ NVIDIA GPU (CUDA) Detected!")
35
- # else:
36
- # device = "cpu"
37
- # print("🐢 No GPU detected, running on CPU (this might be slow).")
38
 
39
  # 1. Clear existing DB if any (to avoid duplicates/corruption)
40
  if CHROMA_DB_DIR.exists():
 
1
  import os
2
  import shutil
3
  import sys
 
4
  from pathlib import Path
5
 
6
  # Add project root to Python path
 
20
  # FIX: Disable Tokenizers Parallelism to prevent deadlocks on macOS
21
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
22
 
23
+ # Force CPU for data ingestion to avoid MPS (Metal) async hangs during long processing.
24
+ # Reliability is key for building the DB; GPU acceleration is only needed for inference.
25
  device = "cpu"
26
  print("🐢 Forcing CPU for stable database ingestion (prevents macOS Freezes).")
 
 
 
 
 
 
 
 
 
 
27
 
28
  # 1. Clear existing DB if any (to avoid duplicates/corruption)
29
  if CHROMA_DB_DIR.exists():
src/main.py CHANGED
@@ -15,8 +15,6 @@ from src.user.profile_store import (
15
  update_book_rating, update_reading_status, update_book_comment,
16
  get_favorites_with_metadata, get_reading_stats
17
  )
18
- from src.marketing.persona import build_persona
19
- from src.marketing.highlights import generate_highlights
20
  from src.api.chat import router as chat_router # ✨ NEW
21
  from src.services.chat_service import chat_service # ✨ NEW
22
  from src.services.recommend_service import RecommendationService # ✨ NEW
@@ -236,9 +234,6 @@ async def favorites_list(user_id: str):
236
  favorites_meta = get_favorites_with_metadata(user_id)
237
  # ENGINEERING IMPROVEMENT: Zero-RAM Lookup
238
  from src.core.metadata_store import metadata_store
239
-
240
- results = []
241
- # Lazy load fetcher (Handled inside utils now)
242
  from src.utils import enrich_book_metadata
243
 
244
  results = []
 
15
  update_book_rating, update_reading_status, update_book_comment,
16
  get_favorites_with_metadata, get_reading_stats
17
  )
 
 
18
  from src.api.chat import router as chat_router # ✨ NEW
19
  from src.services.chat_service import chat_service # ✨ NEW
20
  from src.services.recommend_service import RecommendationService # ✨ NEW
 
234
  favorites_meta = get_favorites_with_metadata(user_id)
235
  # ENGINEERING IMPROVEMENT: Zero-RAM Lookup
236
  from src.core.metadata_store import metadata_store
 
 
 
237
  from src.utils import enrich_book_metadata
238
 
239
  results = []
src/marketing/persona.py CHANGED
@@ -1,15 +1,17 @@
1
  from collections import Counter
2
- from typing import Dict, List, Any
3
- import pandas as pd
4
 
5
  from src.utils import setup_logger
6
 
7
  logger = setup_logger(__name__)
8
 
9
 
10
- def build_persona(fav_isbns: List[str], books: pd.DataFrame) -> Dict[str, Any]:
11
- """Aggregate a simple persona from favorites: top authors and categories."""
12
- if not isinstance(books, pd.DataFrame) or books.empty or not fav_isbns:
 
 
 
13
  return {
14
  "summary": "No profile yet. Start by adding your favorite books to see personalized recommendations.",
15
  "top_authors": [],
 
1
  from collections import Counter
2
+ from typing import Dict, List, Any, Optional
 
3
 
4
  from src.utils import setup_logger
5
 
6
  logger = setup_logger(__name__)
7
 
8
 
9
+ def build_persona(fav_isbns: List[str], books: Optional[Any] = None) -> Dict[str, Any]:
10
+ """
11
+ Aggregate a simple persona from favorites: top authors and categories.
12
+ Uses MetadataStore for lookups; the books param is legacy and unused.
13
+ """
14
+ if not fav_isbns:
15
  return {
16
  "summary": "No profile yet. Start by adding your favorite books to see personalized recommendations.",
17
  "top_authors": [],
src/marketing/personalized_highlight.py CHANGED
@@ -1,5 +1,4 @@
1
  import json
2
- import pandas as pd
3
  from src.marketing.persona import build_persona
4
  from src.marketing.highlights import generate_highlights
5
 
 
1
  import json
 
2
  from src.marketing.persona import build_persona
3
  from src.marketing.highlights import generate_highlights
4
 
src/marketing/verify_p3.py CHANGED
@@ -7,7 +7,7 @@ import torch
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from peft import PeftModel
9
  from modelscope import snapshot_download
10
- from guardrails import ContentGuardrail
11
 
12
  # Config
13
  BASE_MODEL_ID = "qwen/Qwen2-7B-Instruct"
 
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from peft import PeftModel
9
  from modelscope import snapshot_download
10
+ from src.marketing.guardrails import ContentGuardrail
11
 
12
  # Config
13
  BASE_MODEL_ID = "qwen/Qwen2-7B-Instruct"
src/recall/embedding.py CHANGED
@@ -5,16 +5,63 @@ V2.7: Replaced torch.matmul brute-force search with Faiss IndexFlatIP
5
  for SIMD-accelerated inner-product retrieval.
6
  """
7
 
8
- import torch
9
- import numpy as np
10
  import pickle
11
  import logging
12
- import faiss
13
  from pathlib import Path
 
 
 
 
 
 
 
 
 
 
 
 
14
  from src.recall.youtube_dnn import YoutubeDNN
15
 
16
  logger = logging.getLogger(__name__)
17
 
18
  class YoutubeDNNRecall:
19
  def __init__(self, data_dir='data/rec', model_dir='data/model/recall'):
20
  self.data_dir = Path(data_dir)
@@ -33,7 +80,117 @@ class YoutubeDNNRecall:
33
  self.id_to_item = {}
34
  self.meta = None
35
 
36
- def load(self):
 
37
  try:
38
  logger.info("Loading YoutubeDNN model...")
39
  # Load metadata
 
5
  for SIMD-accelerated inner-product retrieval.
6
  """
7
 
 
 
8
  import pickle
9
  import logging
 
10
  from pathlib import Path
11
+ from typing import Optional
12
+
13
+ import faiss
14
+ import numpy as np
15
+ import pandas as pd
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.optim as optim
19
+ from torch.utils.data import Dataset, DataLoader
20
+ from tqdm import tqdm
21
+
22
+ from src.recall.sequence_utils import build_sequences_from_df
23
  from src.recall.youtube_dnn import YoutubeDNN
24
 
25
  logger = logging.getLogger(__name__)
26
 
27
+
28
+ class _RetrievalDataset(Dataset):
29
+ """Internal dataset for YoutubeDNN training (history -> target)."""
30
+
31
+ def __init__(self, user_seqs: dict, item_to_cate: dict, default_cate: int, max_history: int):
32
+ self.samples: list[tuple[list[int], int, int]] = []
33
+ self.item_to_cate = item_to_cate
34
+ self.default_cate = default_cate
35
+ self.max_history = max_history
36
+
37
+ for user, seq in user_seqs.items():
38
+ if len(seq) < 3:
39
+ continue
40
+ train_seq = seq[:-2]
41
+ for i in range(1, len(train_seq)):
42
+ target = train_seq[i]
43
+ history = train_seq[:i]
44
+ if len(history) > max_history:
45
+ history = history[-max_history:]
46
+ target_cate = item_to_cate.get(target, default_cate)
47
+ self.samples.append((history, target, target_cate))
48
+
49
+ def __len__(self) -> int:
50
+ return len(self.samples)
51
+
52
+ def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
53
+ history, target, target_cate = self.samples[idx]
54
+ padded = np.zeros(self.max_history, dtype=np.int64)
55
+ length = min(len(history), self.max_history)
56
+ if length > 0:
57
+ padded[:length] = history[-length:]
58
+ return (
59
+ torch.LongTensor(padded),
60
+ torch.tensor(target, dtype=torch.long),
61
+ torch.tensor(target_cate, dtype=torch.long),
62
+ )
63
+
64
+
65
  class YoutubeDNNRecall:
66
  def __init__(self, data_dir='data/rec', model_dir='data/model/recall'):
67
  self.data_dir = Path(data_dir)
 
80
  self.id_to_item = {}
81
  self.meta = None
82
 
83
+ def fit(
84
+ self,
85
+ df: pd.DataFrame,
86
+ books_path: Optional[Path] = None,
87
+ epochs: int = 10,
88
+ batch_size: int = 512,
89
+ lr: float = 0.001,
90
+ embed_dim: int = 64,
91
+ max_history: int = 20,
92
+ ) -> "YoutubeDNNRecall":
93
+ """
94
+ Train YoutubeDNN from interaction DataFrame. Builds sequences internally.
95
+
96
+ Args:
97
+ df: [user_id, isbn, timestamp] (timestamp optional)
98
+ books_path: Path to books_processed.csv for categories. If None, uses default.
99
+ epochs, batch_size, lr: Training hyperparameters.
100
+ """
101
+ logger.info("Building sequences from DataFrame...")
102
+ user_seqs, item_map = build_sequences_from_df(df, max_len=50)
103
+
104
+ self.item_map = item_map
105
+ self.id_to_item = {v: k for k, v in item_map.items()}
106
+ vocab_size = len(item_map) + 1
107
+
108
+ # Category map
109
+ cate_map: dict[str, int] = {"<PAD>": 0, "<UNK>": 1}
110
+ item_to_cate: dict[int, int] = {}
111
+ default_cate = 1
112
+
113
+ books_path = Path(books_path) if books_path else self.data_dir.parent / "books_processed.csv"
114
+ if books_path.exists():
115
+ books_df = pd.read_csv(books_path, usecols=["isbn13", "simple_categories"])
116
+ books_df["isbn"] = books_df["isbn13"].astype(str)
117
+ for _, row in books_df.iterrows():
118
+ isbn = str(row["isbn"])
119
+ if isbn in item_map:
120
+ iid = item_map[isbn]
121
+ cates = str(row["simple_categories"]).split(";")
122
+ main_cate = cates[0].strip() if cates else "Unknown"
123
+ if main_cate not in cate_map:
124
+ cate_map[main_cate] = len(cate_map)
125
+ item_to_cate[iid] = cate_map[main_cate]
126
+ default_cate = cate_map.get("Unknown", 1)
127
+
128
+ for iid in range(1, vocab_size):
129
+ if iid not in item_to_cate:
130
+ item_to_cate[iid] = default_cate
131
+
132
+ cate_vocab_size = len(cate_map)
133
+
134
+ user_config = {"vocab_size": vocab_size, "embed_dim": embed_dim, "history_len": max_history}
135
+ item_config = {
136
+ "vocab_size": vocab_size,
137
+ "embed_dim": embed_dim,
138
+ "cate_vocab_size": cate_vocab_size,
139
+ "cate_embed_dim": 32,
140
+ }
141
+ model_config = {"hidden_dims": [128, 64], "dropout": 0.1}
142
+
143
+ dataset = _RetrievalDataset(user_seqs, item_to_cate, default_cate, max_history)
144
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
145
+
146
+ self.model = YoutubeDNN(user_config, item_config, model_config).to(self.device)
147
+ optimizer = optim.Adam(self.model.parameters(), lr=lr)
148
+ criterion = nn.CrossEntropyLoss()
149
+
150
+ logger.info("Training YoutubeDNN...")
151
+ self.model.train()
152
+ for epoch in range(epochs):
153
+ total_loss = 0.0
154
+ steps = 0
155
+ for history, target_item, target_cate in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
156
+ history = history.to(self.device)
157
+ target_item = target_item.to(self.device)
158
+ target_cate = target_cate.to(self.device)
159
+ optimizer.zero_grad()
160
+ user_vec = self.model.user_tower(history)
161
+ item_vec = self.model.item_tower(target_item, target_cate)
162
+ user_vec = nn.functional.normalize(user_vec, p=2, dim=1)
163
+ item_vec = nn.functional.normalize(item_vec, p=2, dim=1)
164
+ logits = torch.matmul(user_vec, item_vec.t()) / 0.1
165
+ labels = torch.arange(len(user_vec)).to(self.device)
166
+ loss = criterion(logits, labels)
167
+ loss.backward()
168
+ optimizer.step()
169
+ total_loss += loss.item()
170
+ steps += 1
171
+ logger.info(f"Epoch {epoch+1} Loss: {total_loss/steps:.4f}")
172
+
173
+ self.save_dir.mkdir(parents=True, exist_ok=True)
174
+ torch.save(self.model.state_dict(), self.model_dir / "youtube_dnn.pt")
175
+ self.meta = {
176
+ "user_config": user_config,
177
+ "item_config": item_config,
178
+ "model_config": model_config,
179
+ "item_to_cate": item_to_cate,
180
+ }
181
+ with open(self.model_dir / "youtube_dnn_meta.pkl", "wb") as f:
182
+ pickle.dump(self.meta, f)
183
+
184
+ self.data_dir.mkdir(parents=True, exist_ok=True)
185
+ with open(self.data_dir / "item_map.pkl", "wb") as f:
186
+ pickle.dump(self.item_map, f)
187
+ with open(self.data_dir / "user_sequences.pkl", "wb") as f:
188
+ pickle.dump(user_seqs, f)
189
+
190
+ logger.info(f"YoutubeDNN saved to {self.model_dir}")
191
+ return self
192
+
193
+ def load(self) -> bool:
194
  try:
195
  logger.info("Loading YoutubeDNN model...")
196
  # Load metadata
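The fit() loop above trains with in-batch negatives: for a batch of B (user, item) pairs, the B×B similarity matrix treats the diagonal as positives and every other item in the batch as a negative. A standalone sketch of that objective with random tensors and the same 0.1 temperature:

# Sketch of the in-batch negative objective used in YoutubeDNNRecall.fit():
# each row's matching item (the diagonal entry) is the positive class.
import torch
import torch.nn as nn

B, D = 4, 64
user_vec = nn.functional.normalize(torch.randn(B, D), p=2, dim=1)
item_vec = nn.functional.normalize(torch.randn(B, D), p=2, dim=1)

logits = user_vec @ item_vec.t() / 0.1  # (B, B), temperature-scaled
labels = torch.arange(B)                # positive = diagonal entry
loss = nn.CrossEntropyLoss()(logits, labels)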
src/recall/fusion.py CHANGED
@@ -1,5 +1,7 @@
1
  import logging
2
  from collections import defaultdict
 
 
3
  from src.recall.itemcf import ItemCF
4
  from src.recall.usercf import UserCF
5
  from src.recall.popularity import PopularityRecall
@@ -10,8 +12,43 @@ from src.recall.sasrec_recall import SASRecRecall
10
 
11
  logger = logging.getLogger(__name__)
12
 
13
  class RecallFusion:
14
- def __init__(self, data_dir='data/rec', model_dir='data/model/recall'):
 
 
 
 
 
 
 
 
 
 
 
15
  self.itemcf = ItemCF(data_dir, model_dir)
16
  self.usercf = UserCF(data_dir, model_dir)
17
  self.popularity = PopularityRecall(data_dir, model_dir)
@@ -21,8 +58,8 @@ class RecallFusion:
21
  self.sasrec = SASRecRecall(data_dir, model_dir)
22
 
23
  self.models_loaded = False
24
-
25
- def load_models(self):
26
  if self.models_loaded:
27
  return
28
 
@@ -35,58 +72,57 @@ class RecallFusion:
35
  self.item2vec.load()
36
  self.sasrec.load()
37
  self.models_loaded = True
38
-
39
- def get_recall_items(self, user_id, history_items=None, k=100):
40
  """
41
- Multi-channel recall fusion using RRF
42
  """
43
  if not self.models_loaded:
44
  self.load_models()
45
-
46
  candidates = defaultdict(float)
47
-
48
- # 1. YoutubeDNN (High weight for potential semantic match)
49
- dnn_recs = self.youtube_dnn.recommend(user_id, history_items, top_k=k)
50
- self._add_to_candidates(candidates, dnn_recs, weight=0.1)
51
-
52
- # 2. ItemCF
53
- icf_recs = self.itemcf.recommend(user_id, history_items, top_k=k)
54
- self._add_to_candidates(candidates, icf_recs, weight=1.0)
55
-
56
- # 3. UserCF
57
- ucf_recs = self.usercf.recommend(user_id, history_items, top_k=k)
58
- self._add_to_candidates(candidates, ucf_recs, weight=1.0)
59
-
60
- # 4. Swing
61
- swing_recs = self.swing.recommend(user_id, history_items, top_k=k)
62
- self._add_to_candidates(candidates, swing_recs, weight=1.0)
63
-
64
- # 5. SASRec Embedding
65
- sas_recs = self.sasrec.recommend(user_id, history_items, top_k=k)
66
- self._add_to_candidates(candidates, sas_recs, weight=1.0)
67
-
68
- # 6. Item2Vec
69
- i2v_recs = self.item2vec.recommend(user_id, history_items, top_k=k)
70
- self._add_to_candidates(candidates, i2v_recs, weight=0.8)
71
-
72
- # 7. Popularity (Filler)
73
- pop_recs = self.popularity.recommend(user_id, top_k=k)
74
- self._add_to_candidates(candidates, pop_recs, weight=0.5)
75
-
76
- # Sort by RRF score
77
  sorted_cands = sorted(candidates.items(), key=lambda x: x[1], reverse=True)
78
  return sorted_cands[:k]
79
-
80
- def _add_to_candidates(self, candidates, recs, weight=1.0, rrf_k=60):
81
  """
82
- Add recommendations to candidate pool using RRF
83
- score += weight * (1 / (k + rank))
84
  """
85
  if not recs:
86
  return
87
-
88
  for rank, (item, score) in enumerate(recs):
89
- rrf_score = weight * (1.0 / (rrf_k + rank + 1))
90
  candidates[item] += rrf_score
91
 
92
  if __name__ == "__main__":
 
1
  import logging
2
  from collections import defaultdict
3
+ from typing import Optional
4
+
5
  from src.recall.itemcf import ItemCF
6
  from src.recall.usercf import UserCF
7
  from src.recall.popularity import PopularityRecall
 
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
+ # Default: only the 3 most effective channels enabled. Others available but off.
16
+ DEFAULT_CHANNEL_CONFIG = {
17
+ "itemcf": {"enabled": True, "weight": 1.0},
18
+ "sasrec": {"enabled": True, "weight": 1.0},
19
+ "youtube_dnn": {"enabled": True, "weight": 1.0},
20
+ "usercf": {"enabled": False, "weight": 1.0},
21
+ "swing": {"enabled": False, "weight": 1.0},
22
+ "item2vec": {"enabled": False, "weight": 0.8},
23
+ "popularity": {"enabled": False, "weight": 0.5},
24
+ }
25
+
26
+
27
+ def _merge_config(default: dict, override: Optional[dict]) -> dict:
28
+ """Deep-merge override into default (shallow per channel)."""
29
+ merged = {k: dict(v) for k, v in default.items()}
30
+ if override:
31
+ for ch, cfg in override.items():
32
+ if ch in merged:
33
+ merged[ch].update(cfg)
34
+ else:
35
+ merged[ch] = dict(cfg)
36
+ return merged
37
+
38
+
39
  class RecallFusion:
40
+ def __init__(
41
+ self,
42
+ data_dir: str = "data/rec",
43
+ model_dir: str = "data/model/recall",
44
+ channel_config: Optional[dict] = None,
45
+ rrf_k: int = 60,
46
+ ):
47
+ self.data_dir = data_dir
48
+ self.model_dir = model_dir
49
+ self.channel_config = _merge_config(DEFAULT_CHANNEL_CONFIG, channel_config)
50
+ self.rrf_k = rrf_k
51
+
52
  self.itemcf = ItemCF(data_dir, model_dir)
53
  self.usercf = UserCF(data_dir, model_dir)
54
  self.popularity = PopularityRecall(data_dir, model_dir)
 
58
  self.sasrec = SASRecRecall(data_dir, model_dir)
59
 
60
  self.models_loaded = False
61
+
62
+ def load_models(self) -> None:
63
  if self.models_loaded:
64
  return
65
 
 
72
  self.item2vec.load()
73
  self.sasrec.load()
74
  self.models_loaded = True
75
+
76
+ def get_recall_items(self, user_id: str, history_items=None, k: int = 100):
77
  """
78
+ Multi-channel recall fusion using RRF. Channels and weights controlled by config.
79
  """
80
  if not self.models_loaded:
81
  self.load_models()
82
+
83
  candidates = defaultdict(float)
84
+ cfg = self.channel_config
85
+
86
+ if cfg.get("youtube_dnn", {}).get("enabled", False):
87
+ recs = self.youtube_dnn.recommend(user_id, history_items, top_k=k)
88
+ self._add_to_candidates(candidates, recs, cfg["youtube_dnn"]["weight"])
89
+
90
+ if cfg.get("itemcf", {}).get("enabled", False):
91
+ recs = self.itemcf.recommend(user_id, history_items, top_k=k)
92
+ self._add_to_candidates(candidates, recs, cfg["itemcf"]["weight"])
93
+
94
+ if cfg.get("usercf", {}).get("enabled", False):
95
+ recs = self.usercf.recommend(user_id, history_items, top_k=k)
96
+ self._add_to_candidates(candidates, recs, cfg["usercf"]["weight"])
97
+
98
+ if cfg.get("swing", {}).get("enabled", False):
99
+ recs = self.swing.recommend(user_id, history_items, top_k=k)
100
+ self._add_to_candidates(candidates, recs, cfg["swing"]["weight"])
101
+
102
+ if cfg.get("sasrec", {}).get("enabled", False):
103
+ recs = self.sasrec.recommend(user_id, history_items, top_k=k)
104
+ self._add_to_candidates(candidates, recs, cfg["sasrec"]["weight"])
105
+
106
+ if cfg.get("item2vec", {}).get("enabled", False):
107
+ recs = self.item2vec.recommend(user_id, history_items, top_k=k)
108
+ self._add_to_candidates(candidates, recs, cfg["item2vec"]["weight"])
109
+
110
+ if cfg.get("popularity", {}).get("enabled", False):
111
+ recs = self.popularity.recommend(user_id, top_k=k)
112
+ self._add_to_candidates(candidates, recs, cfg["popularity"]["weight"])
113
+
114
  sorted_cands = sorted(candidates.items(), key=lambda x: x[1], reverse=True)
115
  return sorted_cands[:k]
116
+
117
+ def _add_to_candidates(self, candidates, recs, weight: float) -> None:
118
  """
119
+ Add recommendations to candidate pool using RRF.
120
+ score += weight * (1 / (rrf_k + rank + 1))
121
  """
122
  if not recs:
123
  return
 
124
  for rank, (item, score) in enumerate(recs):
125
+ rrf_score = weight * (1.0 / (self.rrf_k + rank + 1))
126
  candidates[item] += rrf_score
127
 
128
  if __name__ == "__main__":
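The channel config makes it easy to A/B alternative channel mixes without touching the fusion code. A minimal sketch; the weights and the enabled channels here are illustrative, not tuned values, and the user id / ISBN are placeholders:

# Sketch: re-enable the popularity filler and UserCF with custom weights.
from src.recall.fusion import RecallFusion

fusion = RecallFusion(channel_config={
    "popularity": {"enabled": True, "weight": 0.3},
    "usercf": {"enabled": True, "weight": 0.5},
})
candidates = fusion.get_recall_items("some_user_id", history_items=["isbn_a"], k=50)
# Each candidate's score is the sum over channels of weight * 1 / (rrf_k + rank + 1);
# e.g. rank 0 in a weight-1.0 channel with rrf_k=60 contributes 1/61 ≈ 0.0164.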
src/recall/itemcf.py CHANGED
@@ -1,37 +1,150 @@
1
- import pickle
2
  import math
3
- import numpy as np
4
- import pandas as pd
5
- from tqdm import tqdm
6
  from collections import defaultdict
7
  from pathlib import Path
 
 
 
 
 
8
  import logging
9
 
10
  logger = logging.getLogger(__name__)
11
 
 
 
 
 
 
12
  class ItemCF:
13
  """
14
  Item-based Collaborative Filtering.
15
-
16
- ENGINEERING IMPROVEMENT:
17
- Transitioned from loading a 7GB+ in-memory similarity matrix (pickle) to an
18
- indexed SQLite database (`recall_models.db`). Candidate generation is now
19
- offloaded to highly efficient SQL aggregations.
20
-
21
- This change ensures zero-RAM loading for the similarity matrix while maintaining
22
- 100% mathematical parity with the original Python implementation.
23
  """
24
- def __init__(self, data_dir='data/rec', save_dir='data/model/recall'):
 
25
  self.data_dir = Path(data_dir)
26
  self.save_dir = Path(save_dir)
27
- self.db_path = Path("data/recall_models.db")
28
- self.conn = None
29
-
30
- def load(self):
 
 
31
  if self.db_path.exists():
32
- import sqlite3
33
  try:
34
- self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
35
  logger.info(f"ItemCF: Connected to SQLite {self.db_path}")
36
  return True
37
  except Exception as e:
@@ -86,8 +199,6 @@ class ItemCF:
86
  logger.error(f"ItemCF Query Error: {e}")
87
  return []
88
 
89
- def save(self): pass # Migration is done via script
90
- def fit(self, df): pass # Training should be done separately
91
 
92
  if __name__ == "__main__":
93
  # Test run
 
 
1
  import math
2
+ import sqlite3
 
 
3
  from collections import defaultdict
4
  from pathlib import Path
5
+ from typing import Optional
6
+
7
+ import pandas as pd
8
+ from tqdm import tqdm
9
+
10
  import logging
11
 
12
  logger = logging.getLogger(__name__)
13
 
14
+ # Direction weights for asymmetric co-occurrence (CHANGELOG: forward=1.0, backward=0.7)
15
+ FORWARD_WEIGHT = 1.0
16
+ BACKWARD_WEIGHT = 0.7
17
+
18
+
19
  class ItemCF:
20
  """
21
  Item-based Collaborative Filtering.
22
+
23
+ Co-occurrence similarity with direction weights: when a user reads item A and then item B,
24
+ sim(A,B) += 1.0 (forward), sim(B,A) += 0.7 (backward). This captures temporal
25
+ "read-after" patterns.
26
+
27
+ Persists to SQLite (recall_models.db) for zero-RAM inference.
 
 
28
  """
29
+
30
+ def __init__(self, data_dir: str = "data/rec", save_dir: str = "data/model/recall"):
31
  self.data_dir = Path(data_dir)
32
  self.save_dir = Path(save_dir)
33
+ self.save_dir.mkdir(parents=True, exist_ok=True)
34
+ self.db_path = self.data_dir.parent / "recall_models.db"
35
+ self.conn: Optional[sqlite3.Connection] = None
36
+
37
+ def fit(self, df: pd.DataFrame, top_k_sim: int = 200) -> "ItemCF":
38
+ """
39
+ Build co-occurrence similarity matrix with direction weight, then persist to SQLite.
40
+
41
+ Args:
42
+ df: DataFrame with columns [user_id, isbn, rating, timestamp].
43
+ If timestamp is missing, assumes row order per user.
44
+ top_k_sim: Keep only top-k similar items per item to reduce size.
45
+ """
46
+ logger.info("Building ItemCF similarity matrix (direction-weighted co-occurrence)...")
47
+
48
+ # 1. Build per-user chronologically ordered item sequences
49
+ user_seqs: dict[str, list[tuple[str, float]]] = defaultdict(list)
50
+ if "timestamp" in df.columns:
51
+ for _, row in tqdm(df.iterrows(), total=len(df), desc="Building user sequences"):
52
+ user_seqs[row["user_id"]].append((str(row["isbn"]), float(row["timestamp"])))
53
+ for uid in user_seqs:
54
+ user_seqs[uid] = [x[0] for x in sorted(user_seqs[uid], key=lambda t: t[1])]
55
+ else:
56
+ for _, row in tqdm(df.iterrows(), total=len(df), desc="Building user sequences"):
57
+ user_seqs[row["user_id"]].append(str(row["isbn"]))
58
+
59
+ user_hist = {u: list(items) for u, items in user_seqs.items()}
60
+
61
+ # 2. Count users per item (for cosine normalization)
62
+ item_users: dict[str, set[str]] = defaultdict(set)
63
+ for user_id, items in user_hist.items():
64
+ for item in items:
65
+ item_users[item].add(user_id)
66
+ item_counts = {k: len(v) for k, v in item_users.items()}
67
+
68
+ # 3. Build item-item co-occurrence with direction weight
69
+ sim: dict[str, dict[str, float]] = defaultdict(lambda: defaultdict(float))
70
+ for user_id, items in tqdm(user_hist.items(), desc="Computing co-occurrence"):
71
+ for i in range(len(items)):
72
+ item_i = items[i]
73
+ for j in range(i + 1, len(items)):
74
+ item_j = items[j]
75
+ # Forward: i before j -> sim(i,j) += 1.0
76
+ sim[item_i][item_j] += FORWARD_WEIGHT
77
+ # Backward: j after i -> sim(j,i) += 0.7
78
+ sim[item_j][item_i] += BACKWARD_WEIGHT
79
+
80
+ # 4. Normalize by sqrt(|N_i| * |N_j|) (cosine-style)
81
+ logger.info("Normalizing ItemCF matrix...")
82
+ final_sim: dict[str, dict[str, float]] = {}
83
+ for item_i, related in tqdm(sim.items(), desc="Normalizing"):
84
+ ni = item_counts.get(item_i, 1)
85
+ pruned = sorted(related.items(), key=lambda x: x[1], reverse=True)[:top_k_sim]
86
+ final_sim[item_i] = {}
87
+ for item_j, raw_score in pruned:
88
+ nj = item_counts.get(item_j, 1)
89
+ norm = math.sqrt(ni * nj)
90
+ if norm > 0:
91
+ final_sim[item_i][item_j] = raw_score / norm
92
+
93
+ self._sim_matrix = final_sim
94
+ self._user_hist = user_hist
95
+ self.save()
96
+ logger.info(f"ItemCF built: {len(final_sim)} items, saved to {self.db_path}")
97
+ return self
98
+
99
+ def save(self) -> None:
100
+ """Persist similarity matrix and user history to SQLite."""
101
+ if not hasattr(self, "_sim_matrix") or not hasattr(self, "_user_hist"):
102
+ logger.warning("ItemCF.save: No fitted model to save.")
103
+ return
104
+
105
+ self.db_path.parent.mkdir(parents=True, exist_ok=True)
106
+ conn = sqlite3.connect(str(self.db_path))
107
+ cursor = conn.cursor()
108
+
109
+ cursor.execute("DROP TABLE IF EXISTS item_similarity")
110
+ cursor.execute("""
111
+ CREATE TABLE item_similarity (item1 TEXT, item2 TEXT, score REAL)
112
+ """)
113
+ cursor.execute("DROP TABLE IF EXISTS user_history")
114
+ cursor.execute("""
115
+ CREATE TABLE user_history (user_id TEXT, isbn TEXT)
116
+ """)
117
+
118
+ batch = []
119
+ for item1, related in tqdm(self._sim_matrix.items(), desc="Writing item_similarity"):
120
+ for item2, score in related.items():
121
+ batch.append((item1, item2, score))
122
+ if len(batch) >= 100000:
123
+ cursor.executemany("INSERT INTO item_similarity VALUES (?, ?, ?)", batch)
124
+ batch = []
125
+ if batch:
126
+ cursor.executemany("INSERT INTO item_similarity VALUES (?, ?, ?)", batch)
127
+
128
+ batch = []
129
+ for user_id, isbns in tqdm(self._user_hist.items(), desc="Writing user_history"):
130
+ for isbn in isbns:
131
+ batch.append((user_id, isbn))
132
+ if len(batch) >= 100000:
133
+ cursor.executemany("INSERT INTO user_history VALUES (?, ?)", batch)
134
+ batch = []
135
+ if batch:
136
+ cursor.executemany("INSERT INTO user_history VALUES (?, ?)", batch)
137
+
138
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_item1 ON item_similarity(item1)")
139
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_user ON user_history(user_id)")
140
+ conn.commit()
141
+ conn.close()
142
+ logger.info(f"ItemCF saved to {self.db_path}")
143
+
144
+ def load(self) -> bool:
145
  if self.db_path.exists():
 
146
  try:
147
+ self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
148
  logger.info(f"ItemCF: Connected to SQLite {self.db_path}")
149
  return True
150
  except Exception as e:
 
199
  logger.error(f"ItemCF Query Error: {e}")
200
  return []
201
 
 
 
202
 
203
  if __name__ == "__main__":
204
  # Test run
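
Aside: a hand-check of the direction-weighted similarity documented above (forward 1.0, backward 0.7, sqrt co-count normalization). The toy sequences are assumed for illustration; the arithmetic simply applies the constants used in fit().

import math

# Two users with time-ordered toy sequences: u1 reads A then B, u2 reads B then A.
# Raw direction-weighted co-occurrence:
#   sim(A, B) = 1.0 (u1, forward) + 0.7 (u2, backward) = 1.7
#   sim(B, A) = 0.7 (u1, backward) + 1.0 (u2, forward) = 1.7
# Each item was read by 2 users, so the normalizer is sqrt(2 * 2) = 2.
raw_score = 1.0 + 0.7
normalized = raw_score / math.sqrt(2 * 2)
print(normalized)  # 0.85 -- the value save() would write into item_similarity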
src/recall/popularity.py CHANGED
@@ -1,5 +1,4 @@
1
  import pandas as pd
2
- from collections import defaultdict
3
  import pickle
4
  from pathlib import Path
5
  import logging
 
1
  import pandas as pd
 
2
  import pickle
3
  from pathlib import Path
4
  import logging
src/recall/sasrec_recall.py CHANGED
@@ -10,13 +10,52 @@ for SIMD-accelerated approximate nearest neighbor search.
10
 
11
  import pickle
12
  import logging
13
- import numpy as np
14
- import faiss
15
  from pathlib import Path
 
 
16
 
17
  logger = logging.getLogger(__name__)
18
 
19
 
 
 
20
  class SASRecRecall:
21
  def __init__(self, data_dir='data/rec', model_dir='data/model/recall'):
22
  self.data_dir = Path(data_dir)
@@ -30,7 +69,123 @@ class SASRecRecall:
30
  self.faiss_index = None # Faiss IndexFlatIP for fast inner-product search
31
  self.loaded = False
32
 
33
- def load(self):
 
 
 
 
 
 
 
34
  try:
35
  logger.info("Loading SASRec recall embeddings...")
36
 
 
10
 
11
  import pickle
12
  import logging
 
 
13
  from pathlib import Path
14
+ from typing import Optional
15
+
16
+ import faiss
17
+ import numpy as np
18
+ import pandas as pd
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.optim as optim
22
+ from torch.utils.data import Dataset, DataLoader
23
+ from tqdm import tqdm
24
+
25
+ from src.model.sasrec import SASRec
26
+ from src.recall.sequence_utils import build_sequences_from_df
27
 
28
  logger = logging.getLogger(__name__)
29
 
30
 
31
+ class _SeqDataset(Dataset):
32
+ """Internal dataset for SASRec training (seq, pos, neg)."""
33
+
34
+ def __init__(self, seqs_dict: dict, num_items: int, max_len: int):
35
+ self.seqs: list[list[int]] = []
36
+ self.num_items = num_items
37
+ self.max_len = max_len
38
+
39
+ for seq in seqs_dict.values():
40
+ if len(seq) < 2:
41
+ continue
42
+ padded = [0] * max_len
43
+ seq_len = min(len(seq), max_len)
44
+ padded[-seq_len:] = seq[-seq_len:]
45
+ self.seqs.append(padded)
46
+ self.seqs = torch.LongTensor(self.seqs)
47
+
48
+ def __len__(self) -> int:
49
+ return len(self.seqs)
50
+
51
+ def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
52
+ seq = self.seqs[idx]
53
+ pos = np.zeros_like(seq.numpy())
54
+ pos[:-1] = seq.numpy()[1:]
55
+ neg = np.random.randint(1, self.num_items + 1, size=len(seq))
56
+ return seq, torch.LongTensor(pos), torch.LongTensor(neg)
57
+
58
+
59
  class SASRecRecall:
60
  def __init__(self, data_dir='data/rec', model_dir='data/model/recall'):
61
  self.data_dir = Path(data_dir)
 
69
  self.faiss_index = None # Faiss IndexFlatIP for fast inner-product search
70
  self.loaded = False
71
 
72
+ def fit(
73
+ self,
74
+ df: pd.DataFrame,
75
+ max_len: int = 50,
76
+ hidden_dim: int = 64,
77
+ epochs: int = 30,
78
+ batch_size: int = 128,
79
+ lr: float = 1e-4,
80
+ ) -> "SASRecRecall":
81
+ """
82
+ Train SASRec from an interaction DataFrame. Sequences are built internally.
83
+
84
+ Args:
85
+ df: [user_id, isbn, timestamp] (timestamp optional)
86
+ max_len, hidden_dim, epochs, batch_size, lr: Training hyperparameters.
87
+ """
88
+ logger.info("Building sequences from DataFrame...")
89
+ user_seqs, item_map = build_sequences_from_df(df, max_len=max_len)
90
+
91
+ self.item_map = item_map
92
+ self.id_to_item = {v: k for k, v in item_map.items()}
93
+ num_items = len(item_map)
94
+
95
+ if torch.cuda.is_available():
96
+ device = torch.device("cuda")
97
+ elif torch.backends.mps.is_available():
98
+ device = torch.device("mps")
99
+ else:
100
+ device = torch.device("cpu")
101
+
102
+ logger.info(f"Training SASRec on {device}...")
103
+ dataset = _SeqDataset(user_seqs, num_items, max_len)
104
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
105
+
106
+ model = SASRec(num_items, max_len, hidden_dim).to(device)
107
+ optimizer = optim.Adam(model.parameters(), lr=lr)
108
+ criterion = nn.BCEWithLogitsLoss()
109
+
110
+ model.train()
111
+ for epoch in range(epochs):
112
+ total_loss = 0.0
113
+ pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
114
+ for seq, pos, neg in pbar:
115
+ seq, pos, neg = seq.to(device), pos.to(device), neg.to(device)
116
+ seq_emb = model(seq)
117
+ mask = pos != 0
118
+ pos_emb = model.item_emb(pos)
119
+ neg_emb = model.item_emb(neg)
120
+ pos_logits = (seq_emb * pos_emb).sum(dim=-1)[mask]
121
+ neg_logits = (seq_emb * neg_emb).sum(dim=-1)[mask]
122
+ pos_labels = torch.ones_like(pos_logits)
123
+ neg_labels = torch.zeros_like(neg_logits)
124
+ loss = criterion(pos_logits, pos_labels) + criterion(neg_logits, neg_labels)
125
+ optimizer.zero_grad()
126
+ loss.backward()
127
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
128
+ optimizer.step()
129
+ total_loss += loss.item()
130
+ pbar.set_postfix(loss=total_loss / (pbar.n + 1))
131
+
132
+ # Save model
133
+ sasrec_dir = self.model_dir.parent / "rec"
134
+ sasrec_dir.mkdir(parents=True, exist_ok=True)
135
+ torch.save(model.state_dict(), sasrec_dir / "sasrec_model.pth")
136
+
137
+ # Extract user embeddings
138
+ logger.info("Extracting user sequence embeddings...")
139
+ model.eval()
140
+ user_emb_dict: dict = {}
141
+ all_users = list(user_seqs.keys())
142
+
143
+ with torch.no_grad():
144
+ for i in tqdm(range(0, len(all_users), batch_size), desc="Embedding users"):
145
+ batch_users = all_users[i : i + batch_size]
146
+ batch_seqs = []
147
+ for u in batch_users:
148
+ s = user_seqs[u]
149
+ padded = [0] * max_len
150
+ seq_len = min(len(s), max_len)
151
+ if seq_len > 0:
152
+ padded[-seq_len:] = s[-seq_len:]
153
+ batch_seqs.append(padded)
154
+ input_tensor = torch.LongTensor(batch_seqs).to(device)
155
+ output = model(input_tensor)
156
+ last_state = output[:, -1, :].cpu().numpy()
157
+ for j, u in enumerate(batch_users):
158
+ user_emb_dict[u] = last_state[j]
159
+
160
+ self.data_dir.mkdir(parents=True, exist_ok=True)
161
+ with open(self.data_dir / "user_seq_emb.pkl", "wb") as f:
162
+ pickle.dump(user_emb_dict, f)
163
+ with open(self.data_dir / "item_map.pkl", "wb") as f:
164
+ pickle.dump(self.item_map, f)
165
+ with open(self.data_dir / "user_sequences.pkl", "wb") as f:
166
+ pickle.dump(user_seqs, f)
167
+
168
+ self.user_seq_emb = user_emb_dict
169
+ self.user_hist = {
170
+ u: set(self.id_to_item[idx] for idx in seq if idx in self.id_to_item)
171
+ for u, seq in user_seqs.items()
172
+ }
173
+ self.item_emb = model.item_emb.weight.detach().cpu().numpy()
174
+ self._build_faiss_index()
175
+ self.loaded = True
176
+
177
+ logger.info(f"SASRec saved to {sasrec_dir}")
178
+ return self
179
+
180
+ def _build_faiss_index(self) -> None:
181
+ """Build Faiss index from item embeddings."""
182
+ if self.item_emb is None:
183
+ return
184
+ dim = self.item_emb.shape[1]
185
+ self.faiss_index = faiss.IndexFlatIP(dim)
186
+ self.faiss_index.add(np.ascontiguousarray(self.item_emb.astype(np.float32)))
187
+
188
+ def load(self) -> bool:
189
  try:
190
  logger.info("Loading SASRec recall embeddings...")
191
 
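Aside: a usage sketch showing how the artifacts built by fit() (user_seq_emb, item_emb, faiss_index, id_to_item) fit together for candidate generation. The toy DataFrame and the final search step are assumptions for illustration; the actual recommend() body is outside this hunk.

import numpy as np
import pandas as pd

# Toy interactions; column names match what fit() expects.
df = pd.DataFrame({
    "user_id": ["u1", "u1", "u2"],
    "isbn": ["111", "222", "111"],
    "timestamp": [1, 2, 1],
})

recall = SASRecRecall(data_dir="data/rec", model_dir="data/model/recall")
recall.fit(df, max_len=50, hidden_dim=64, epochs=1)  # tiny run just to exercise the path

# Hypothetical candidate generation from the fitted artifacts:
user_vec = recall.user_seq_emb["u1"].astype(np.float32).reshape(1, -1)
scores, ids = recall.faiss_index.search(user_vec, 5)  # inner-product search over item_emb rows
candidates = [recall.id_to_item[int(i)] for i in ids[0] if int(i) in recall.id_to_item]
print(candidates)  # ISBNs ordered by embedding dot-product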
src/recall/sequence_utils.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
1
+ """
2
+ Shared utilities for building user sequences from interaction DataFrames.
3
+ Used by SASRec and YoutubeDNN training.
4
+ """
5
+
6
+ from typing import Tuple
7
+
8
+ import pandas as pd
9
+ from tqdm import tqdm
10
+
11
+
12
+ def build_sequences_from_df(
13
+ df: pd.DataFrame, max_len: int = 50
14
+ ) -> Tuple[dict[str, list[int]], dict[str, int]]:
15
+ """
16
+ Build user sequences and item map from interaction DataFrame.
17
+
18
+ Args:
19
+ df: DataFrame with columns [user_id, isbn] and optionally [timestamp].
20
+ max_len: Maximum sequence length (truncate from the left).
21
+
22
+ Returns:
23
+ user_seqs: Dict[user_id, list of item_ids] (1-indexed, 0 is padding)
24
+ item_map: Dict[isbn, item_id]
25
+ """
26
+ items = df["isbn"].astype(str).unique()
27
+ item_map = {isbn: i + 1 for i, isbn in enumerate(items)}
28
+
29
+ user_history: dict[str, list[tuple[str, float]]] = {}
30
+ has_ts = "timestamp" in df.columns
31
+
32
+ for _, row in tqdm(df.iterrows(), total=len(df), desc="Building sequences"):
33
+ u = str(row["user_id"])
34
+ isbn = str(row["isbn"])
35
+ ts = float(row["timestamp"]) if has_ts else 0.0
36
+ if u not in user_history:
37
+ user_history[u] = []
38
+ user_history[u].append((isbn, ts))
39
+
40
+ user_seqs: dict[str, list[int]] = {}
41
+ for u, pairs in user_history.items():
42
+ if has_ts:
43
+ pairs.sort(key=lambda x: x[1])
44
+ item_ids = [item_map.get(isbn, 0) for isbn, _ in pairs]
45
+ item_ids = [x for x in item_ids if x != 0]
46
+ user_seqs[u] = item_ids[-max_len:]
47
+
48
+ return user_seqs, item_map
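
Aside: a quick illustration of the helper's contract on toy data (the rows below are assumed, not taken from the dataset).

import pandas as pd

df = pd.DataFrame({
    "user_id": ["u1", "u1", "u1", "u2"],
    "isbn": ["aaa", "bbb", "aaa", "bbb"],
    "timestamp": [3, 1, 2, 1],
})

user_seqs, item_map = build_sequences_from_df(df, max_len=50)
# Item ids are 1-based in first-seen order; 0 is reserved for padding.
print(item_map)   # {'aaa': 1, 'bbb': 2}
# u1's rows are sorted by timestamp (bbb, aaa, aaa) before mapping to ids.
print(user_seqs)  # {'u1': [2, 1, 1], 'u2': [2]}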
src/recommender.py CHANGED
@@ -1,13 +1,9 @@
1
- import pandas as pd
2
  from typing import List, Dict, Any
3
- from src.etl import load_books_data
4
  from src.vector_db import VectorDB
5
  from src.config import TOP_K_INITIAL, TOP_K_FINAL, DATA_DIR
6
  from src.cache import CacheManager
7
 
8
- from src.utils import setup_logger, summarize_description
9
- from src.cover_fetcher import fetch_book_cover
10
- from src.marketing.personalized_highlight import get_persona_and_highlights
11
  from src.core.metadata_store import metadata_store
12
 
13
  logger = setup_logger(__name__)
 
 
1
  from typing import List, Dict, Any
 
2
  from src.vector_db import VectorDB
3
  from src.config import TOP_K_INITIAL, TOP_K_FINAL, DATA_DIR
4
  from src.cache import CacheManager
5
 
6
+ from src.utils import setup_logger
 
 
7
  from src.core.metadata_store import metadata_store
8
 
9
  logger = setup_logger(__name__)
src/services/chat_service.py CHANGED
@@ -1,22 +1,22 @@
1
  from typing import Generator, Optional, Dict, Any, List
2
- import pandas as pd
3
  from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, BaseMessage
4
 
5
  from src.core.llm import LLMFactory
6
- from src.etl import load_books_data
7
  from src.marketing.persona import build_persona
8
  from src.user.profile_store import list_favorites
9
  from src.utils import setup_logger
10
 
11
  logger = setup_logger(__name__)
12
 
 
13
  class ChatService:
14
  """
15
  Service for RAG-based chat interaction.
16
  Currently focused on 'Chat with Book' (Single Item Context).
 
17
  """
18
  _instance = None
19
- _books_df = None
20
  _history: Dict[str, List[BaseMessage]] = {}
21
 
22
  def __new__(cls):
@@ -25,25 +25,11 @@ class ChatService:
25
  return cls._instance
26
 
27
  def __init__(self):
28
- # Data is now loaded lazily via _ensure_data
29
  pass
30
 
31
- def _ensure_data(self):
32
- if self._books_df is None:
33
- logger.info("ChatService: Lazy-loading books data for context retrieval...")
34
- self._books_df = load_books_data()
35
-
36
  def _get_book_context(self, isbn: str) -> Optional[Dict[str, Any]]:
37
- """Retrieve full context for a specific book by ISBN."""
38
- self._ensure_data()
39
- # Handle string/int types for ISBN
40
- try:
41
- row = self._books_df[self._books_df["isbn13"].astype(str) == str(isbn)]
42
- if row.empty:
43
- return None
44
- return row.iloc[0].to_dict()
45
- except Exception:
46
- return None
47
 
48
  def _format_book_info(self, book: Dict[str, Any]) -> str:
49
  """Format book metadata into a readable context string."""
@@ -97,8 +83,7 @@ class ChatService:
97
  """
98
  Stream chat response for a specific book.
99
  """
100
- self._ensure_data()
101
- # 1. Fetch Context
102
  book = self._get_book_context(isbn)
103
  if not book:
104
  yield "I'm sorry, I couldn't find the details for this book."
@@ -106,7 +91,7 @@ class ChatService:
106
 
107
  # 2. Build Persona (User Profile)
108
  favs = list_favorites(user_id)
109
- persona_data = build_persona(favs, self._books_df)
110
  user_persona = persona_data.get("summary", "General Reader")
111
 
112
  # 3. Construct Prompt with History
@@ -158,15 +143,11 @@ class ChatService:
158
  yield f"Error generating response: {str(e)}. Please check your API Key."
159
 
160
  def add_book_to_context(self, book_data: Dict[str, Any]):
161
- """Dynamically add a new book to the ChatService context."""
162
- self._ensure_data()
163
- try:
164
- if self._books_df is not None:
165
- new_row_df = pd.DataFrame([book_data])
166
- self._books_df = pd.concat([self._books_df, new_row_df], ignore_index=True)
167
- logger.info(f"ChatService: Added book {book_data.get('isbn13')} to context.")
168
- except Exception as e:
169
- logger.error(f"ChatService: Failed to add book to context: {e}")
170
 
171
  def get_chat_service():
172
  """Helper for lazy access to the ChatService singleton."""
 
1
  from typing import Generator, Optional, Dict, Any, List
 
2
  from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, BaseMessage
3
 
4
  from src.core.llm import LLMFactory
5
+ from src.data.repository import data_repository
6
  from src.marketing.persona import build_persona
7
  from src.user.profile_store import list_favorites
8
  from src.utils import setup_logger
9
 
10
  logger = setup_logger(__name__)
11
 
12
+
13
  class ChatService:
14
  """
15
  Service for RAG-based chat interaction.
16
  Currently focused on 'Chat with Book' (Single Item Context).
17
+ Uses DataRepository for all book metadata lookups.
18
  """
19
  _instance = None
 
20
  _history: Dict[str, List[BaseMessage]] = {}
21
 
22
  def __new__(cls):
 
25
  return cls._instance
26
 
27
  def __init__(self):
 
28
  pass
29
 
 
 
 
 
 
30
  def _get_book_context(self, isbn: str) -> Optional[Dict[str, Any]]:
31
+ """Retrieve full context for a specific book by ISBN via DataRepository."""
32
+ return data_repository.get_book_metadata(str(isbn))
 
 
 
 
 
 
 
 
33
 
34
  def _format_book_info(self, book: Dict[str, Any]) -> str:
35
  """Format book metadata into a readable context string."""
 
83
  """
84
  Stream chat response for a specific book.
85
  """
86
+ # 1. Fetch Context via DataRepository
 
87
  book = self._get_book_context(isbn)
88
  if not book:
89
  yield "I'm sorry, I couldn't find the details for this book."
 
91
 
92
  # 2. Build Persona (User Profile)
93
  favs = list_favorites(user_id)
94
+ persona_data = build_persona(favs)
95
  user_persona = persona_data.get("summary", "General Reader")
96
 
97
  # 3. Construct Prompt with History
 
143
  yield f"Error generating response: {str(e)}. Please check your API Key."
144
 
145
  def add_book_to_context(self, book_data: Dict[str, Any]):
146
+ """
147
+ Called when a new book is added to the system. The book is already in the MetadataStore
148
+ via recommender.add_new_book, so there is no in-memory cache to update. No-op for now.
149
+ """
150
+ logger.info(f"ChatService: Book {book_data.get('isbn13')} added; context served from MetadataStore.")
 
 
 
 
151
 
152
  def get_chat_service():
153
  """Helper for lazy access to the ChatService singleton."""
src/vector_db.py CHANGED
@@ -1,10 +1,7 @@
1
- import gc
2
  from typing import List, Any
3
  # Using community version to avoid 'BaseBlobParser' version conflict in langchain-chroma/core
4
  from langchain_community.vectorstores import Chroma
5
  from langchain_huggingface import HuggingFaceEmbeddings
6
- from langchain_community.document_loaders import TextLoader
7
- from langchain_text_splitters import CharacterTextSplitter
8
  from src.config import REVIEW_HIGHLIGHTS_TXT, CHROMA_DB_DIR, EMBEDDING_MODEL
9
  from src.utils import setup_logger
10
  from src.core.metadata_store import metadata_store
 
 
1
  from typing import List, Any
2
  # Using community version to avoid 'BaseBlobParser' version conflict in langchain-chroma/core
3
  from langchain_community.vectorstores import Chroma
4
  from langchain_huggingface import HuggingFaceEmbeddings
 
 
5
  from src.config import REVIEW_HIGHLIGHTS_TXT, CHROMA_DB_DIR, EMBEDDING_MODEL
6
  from src.utils import setup_logger
7
  from src.core.metadata_store import metadata_store