ymlin105 committed on
Commit 5af0c50 · 1 Parent(s): 3f281f1

chore: update requirements and documentation for intent classifier and RAG evaluation

data/rag_golden.README.md ADDED
@@ -0,0 +1,32 @@
+# RAG Golden Test Set
+
+Human-annotated Query-Book pairs for quantitative RAG evaluation.
+
+## Format
+
+CSV with columns: `query`, `isbn`, `relevance`, `notes`
+
+- **query**: User search string (e.g., "Harry Potter", "0060959479", "books about AI")
+- **isbn**: Expected relevant book ISBN (from your catalog)
+- **relevance**: 1 = relevant (the evaluator keeps only rows with relevance=1)
+- **notes**: Optional annotation note
+
+Multiple rows per query = multiple relevant books (Recall@K counts all).
+
+## Usage
+
+```bash
+# Copy the example file and extend it with ISBNs from your catalog
+cp data/rag_golden.example.csv data/rag_golden.csv
+
+# Run evaluation
+python scripts/model/evaluate_rag.py --golden data/rag_golden.csv
+```
+
+## Metrics
+
+- **Accuracy@K**: Fraction of queries with at least one relevant book in top-K
+- **Recall@K**: Fraction of relevant books (across all queries) found in top-K
+- **MRR@K**: Mean reciprocal rank of the first relevant hit
+
+Target: 500+ pairs for production-quality evaluation.
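The three metrics above can be computed directly from the golden pairs and the retriever's ranked output; a minimal sketch (illustrative only, independent of the repo's `evaluate_rag.py`):

```python
# Illustrative metric computation. `golden` maps each query to its annotated
# relevant ISBNs (assumed non-empty); `ranked` maps it to the system's ranked ISBNs.
def rank_metrics(golden: dict[str, set[str]], ranked: dict[str, list[str]], k: int = 10):
    acc = recall = mrr = 0.0
    for query, relevant in golden.items():
        top = ranked.get(query, [])[:k]
        hits = [i for i, isbn in enumerate(top) if isbn in relevant]
        if hits:                              # Accuracy@K: any relevant item in top-K
            acc += 1
            mrr += 1.0 / (hits[0] + 1)        # MRR@K: reciprocal rank of first hit
        recall += len(hits) / len(relevant)   # Recall@K: fraction of relevant found
    n = len(golden)
    return {"accuracy@k": acc / n, "recall@k": recall / n, "mrr@k": mrr / n}
```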
docs/TECHNICAL_REPORT.md CHANGED
@@ -217,6 +217,8 @@ Architecture: Self-Attentive Sequential Recommendation with Transformer blocks
 - Training: 30 epochs, 64-dim embeddings, BCE loss with negative sampling
 - Dual use: (1) ranking feature via `sasrec_score`, (2) independent recall channel via embedding dot-product
 
+**Time-split (no leakage)**: SASRec is trained on `train.csv` only. `user_seq_emb` and `sas_item_emb` are computed from train-only sequences. When Ranking uses `sasrec_score` for val samples, the user's history contains only train interactions, never val/test. `build_sequences.py` and SASRec/YoutubeDNN all use train-only data.
+
 ### 4.3 LGBMRanker (LambdaRank) + Model Stacking
 
 Replaced XGBoost binary classifier with LightGBM LambdaRank that directly optimizes NDCG. In v2.6.0, a Stacking ensemble (LGBMRanker + XGBClassifier → LogisticRegression meta-learner) further improves ranking robustness.
@@ -226,6 +228,8 @@ Replaced XGBoost binary classifier with LightGBM LambdaRank that directly optimizes NDCG
 - 20K users sampled from 168K validation set for training speed
 - 4× negative ratio per positive sample
 
+**Feature consistency**: Recall models (SASRec, ItemCF, etc.) are trained on train.csv. Ranking labels come from val.csv. Features like `sasrec_score` use train-only embeddings. Pipeline order: `split_rec_data` → `build_sequences` (train-only) → recall models (train) → ranker (val).
+
 **17 features** in 5 groups:
 - User statistics: u_cnt, u_mean, u_std
 - Item statistics: i_cnt, i_mean, i_std
@@ -264,10 +268,12 @@ Feature importance (v2.6.0 LGBMRanker, representative subset):
 |--------|------------------------|-------------|
 | ISBN Recall | 0% | 100% |
 | Keyword Precision | Low | High (BM25 boost) |
-| Detail Query Recall | 0% | Demonstrated via curated examples (Small-to-Big) |
+| Detail Query Recall | 0% | Golden Test Set (Accuracy@K, Recall@K, MRR@K) |
 | Avg Latency | 100ms | 300-800ms |
 | Chat Context Limit | ~10 turns | Extended via compression (no formal limit) |
 
+**Golden Test Set**: Human-annotated Query-Book pairs (`data/rag_golden.csv`) replace curated examples. Run `python scripts/model/evaluate_rag.py` for Accuracy@K, Recall@K, MRR@K. Extend to ~500+ pairs for production.
+
 ### 5.2 Latency Benchmarks
 
 | Operation | P50 Latency (Warm) | P95 Latency (Warm) |
@@ -371,7 +377,7 @@ src/
 
 - **Single-dataset evaluation**: All RecSys metrics are on Amazon Books 200K; no cross-domain or external validation.
 - **Rule-based router**: Intent classification uses heuristics (e.g., `len(words) <= 2` for keyword); may not generalize to other domains.
-- **RAG evaluation**: RAG quality is demonstrated via curated examples (e.g., "Harry Potter", ISBN recall); no systematic human evaluation or large-scale relevance judgments.
+- **RAG evaluation**: Use the Golden Test Set (`data/rag_golden.csv`) for Accuracy@K, Recall@K, MRR@K. Extend to 500+ human-annotated Query-Book pairs for production.
 - **Protocol sensitivity**: RecSys metrics can vary with evaluation protocol (e.g., ISBN-only vs title-relaxed matching); see [Experiment Archive](experiments/experiment_archive.md) for discussion.
 
 ---
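The stacking design described in Section 4.3 (LGBMRanker + XGBClassifier → LogisticRegression meta-learner) fits in a few lines; a hedged sketch on synthetic data, not the repo's `train_ranker.py`:

```python
# Minimal stacking sketch with synthetic stand-ins: 200 queries x 5 candidates,
# 17 features (as in the report). Real feature/label construction differs.
import numpy as np
from lightgbm import LGBMRanker
from xgboost import XGBClassifier
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(42)
X_train = rng.normal(size=(1000, 17))
y_train = rng.integers(0, 2, 1000)
groups = [5] * 200                       # candidates per query, required by LambdaRank
X_meta = rng.normal(size=(500, 17))      # disjoint split for the meta-learner
y_meta = rng.integers(0, 2, 500)

lgb = LGBMRanker(objective="lambdarank", n_estimators=50)
lgb.fit(X_train, y_train, group=groups)  # listwise: optimizes NDCG directly

xgb = XGBClassifier(n_estimators=50, eval_metric="logloss")
xgb.fit(X_train, y_train)                # pointwise relevance signal

# Meta-learner blends base scores computed on a split the bases never trained on
meta_X = np.column_stack([lgb.predict(X_meta), xgb.predict_proba(X_meta)[:, 1]])
meta = LogisticRegression().fit(meta_X, y_meta)

def stacked_score(X: np.ndarray) -> np.ndarray:
    base = np.column_stack([lgb.predict(X), xgb.predict_proba(X)[:, 1]])
    return meta.predict_proba(base)[:, 1]
```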
docs/interview_guide.md CHANGED
@@ -5,51 +5,80 @@
 ## 🌟 Core Highlights (Why this project?)
 
 ### 1. Architecture Depth
+
 * **Agentic RAG**: Not just plain vector retrieval; the system adds **Dynamic Routing**. Based on user intent (e.g., exact ISBN search vs. fuzzy semantic search), it automatically picks the best retrieval strategy (BM25, Hybrid, Small-to-Big), demonstrating fine-grained control over the RAG system.
 * **Stacking Ensemble (Model Fusion)**: The ranking stage does not stop at a single model; it implements a **LightGBM + XGBoost + Logistic Regression** stacking architecture, reflecting an understanding of model bias and variance and an engineering drive for the best possible recommendation quality.
 * **Vector Database**: Semantic search built on ChromaDB, in step with the current LLM + Vector Store trend.
 
 ### 2. Engineering Excellence
+
 * **Performance Optimization**:
     * **Problem**: The system stuttered under concurrent load, and inference latency was high.
     * **Solution**:
         1. **Async/Await trap**: CPU-bound work (Pandas operations) was running inside FastAPI `async` routes, blocking the event loop. Adding `await` does not help; non-IO work must either drop `async` or move to a thread pool. Switching to a synchronous `def` lets FastAPI run the handler in its thread pool automatically.
         2. **Vectorization refactor**: Feature generation used native Python `for` loops. Rewriting it as vectorized Numpy/Pandas operations (exploiting SIMD) sped up inference by roughly 10×.
         3. **Singleton pattern**: A `MetadataStore` singleton avoids reloading the CSV on every request, significantly reducing memory usage and I/O overhead.
 * **Explainability**: Integrated **SHAP (SHapley Additive exPlanations)**. The recommender is no longer a "black box"; it can explain in real time why a book was recommended (e.g., because you like author X, or mostly read this genre), an important marker separating junior projects from senior ones.
 
 ### 3. Completeness
+
 * **Full Stack**: Frontend (React) + Backend (FastAPI) + Data flow (ETL) + Model training (Train Scripts) + Deployment (Docker).
 * **DevOps**: Ships a Dockerfile and complete build scripts, ready for production deployment.
 
 ---
 
 ## 🗣️ Interview Talking Points & Q&A Strategy
 
 ### Q1: What was the biggest challenge in this project, and how did you solve it?
+
 **Suggested answer**:
+
 > "What left the deepest impression on me was the **system performance optimization** work.
 > Under high-concurrency load, the first version had very high inference latency and could even block the whole service.
 > I solved it on two levels:
+>
 > 1. **Architecture level**: Profiling showed that FastAPI `async` endpoints contained heavy Pandas processing. Because Python `async` is single-threaded and cooperative, CPU-bound tasks stall the event loop outright. I refactored these endpoints into the non-async mode that uses FastAPI's thread pool, which removed the blocking.
 > 2. **Code level**: Feature engineering was originally written with Python loops. I rewrote it as **Numpy vectorized** operations, replacing O(N) Python-interpreter overhead with C-level matrix math, ultimately speeding up feature generation by more than 10×."
 
 ### Q2: Why a stacking ensemble? Isn't LightGBM alone enough?
+
 **Suggested answer**:
+
 > "A single model always has limitations.
 > LightGBM is strong at categorical features and gradient boosting; XGBoost handles regularization well.
 > With stacking, I use a simple Logistic Regression as the meta-learner over the outputs of these two strong models.
 > This exploits the strengths of different models (lowering both bias and variance) and improves **robustness**. In my offline experiments, stacking clearly beat standalone LightGBM on NDCG@10."
 
 ### Q3: What is special about your RAG system?
+
 **Suggested answer**:
+
 > "My RAG system is not plain 'Retrieve then Generate'. I designed an **Agentic Router**.
 > It first classifies user intent: an ISBN lookup goes straight to exact matching; a fuzzy description goes through the semantic index; a complex query triggers a rerank pass.
 > This dynamic strategy resolves the classic RAG trade-off of 'precise but incomplete vs. complete but imprecise'."
 
+**Q1. On the intuition behind the Swing algorithm:**
+
+> "I see you used Swing recall. Can you explain intuitively why Swing is more noise-resistant than classic UserCF? In the formula `1 / (alpha + |I_u ∩ I_v|)`, what kind of user pair is the denominator penalizing?"
+> *(What it probes: whether you truly understand the algorithm or just call a library. The key is that Swing down-weights user pairs from small cliques that are "already very similar", which promotes serendipity.)*
+
+**Q2. On RAG latency optimization:**
+
+> "Your report says Hybrid Search + Rerank takes about 800ms. If we deployed this in Douyin's search box with a P99 latency budget under 200ms, which stages would you cut, or how would you optimize with engineering?"
+> *(What it probes: engineering thinking. Possible answers include parallel requests, quantized HNSW vector index, reranker distillation, caching hot queries, async loading of details, etc.)*
+
+**Q3. SASRec application details:**
+
+> "In `src/model/sasrec.py` you use a Transformer. At inference time, if we refresh recommendations on every book click, SASRec is computationally expensive. How would you cache the user's embedding state to avoid recomputing the whole sequence from scratch?"
+> *(What it probes: understanding of online inference optimization for deep models. The key is a KV cache or incremental computation.)*
+
 ---
 
 ## 📈 Key Metrics
+
 * **Hit Rate@10**: 0.4545 (v2.6.0, n=2000, Leave-Last-Out)
 * **MRR@5**: 0.2893 (Title-relaxed matching)
 * **Latency**: P99 < 50ms (Personalized Recs)
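The Swing question above hinges on the `1 / (alpha + |I_u ∩ I_v|)` weighting; a minimal, self-contained sketch of Swing item-item similarity (illustrative and quadratic, not the repo's recall implementation):

```python
# Swing similarity sketch: sim(i, j) = sum over user pairs (u, v) that both
# interacted with i and j of 1 / (alpha + |I_u ∩ I_v|).
from itertools import combinations
from collections import defaultdict

def swing_similarity(user_items: dict[str, set[str]], alpha: float = 1.0) -> dict:
    item_users: dict[str, set[str]] = defaultdict(set)
    for u, items in user_items.items():
        for i in items:
            item_users[i].add(u)
    sim: dict[tuple[str, str], float] = defaultdict(float)
    for i, j in combinations(item_users, 2):
        common = item_users[i] & item_users[j]
        for u, v in combinations(common, 2):
            overlap = len(user_items[u] & user_items[v])
            # Large user-user overlap means a "clique" pair; the denominator
            # down-weights it, so evidence from dissimilar users counts more.
            sim[(i, j)] += 1.0 / (alpha + overlap)
    return dict(sim)
```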
requirements.txt CHANGED
@@ -39,6 +39,9 @@ scikit-learn
 scipy
 requests
 
+# Intent classifier backends (optional)
+# fasttext  # Uncomment for FastText backend: pip install fasttext
+
 # LLM Agent & Fine-tuning
 faiss-cpu
 diffusers
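Since `fasttext` ships commented out, the classifier presumably has to degrade gracefully when it is absent; a common guard pattern for such an optional backend (a sketch only, the real `src/core/intent_classifier.py` is not shown in this commit):

```python
# Illustrative optional-dependency guard; names are hypothetical.
try:
    import fasttext  # only available if uncommented in requirements.txt
    HAS_FASTTEXT = True
except ImportError:
    HAS_FASTTEXT = False

def resolve_backend(requested: str) -> str:
    """Fall back with a clear error if an optional backend is missing."""
    if requested == "fasttext" and not HAS_FASTTEXT:
        raise RuntimeError("fasttext backend requested but not installed: pip install fasttext")
    return requested
```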
scripts/data/build_sequences.py CHANGED
@@ -4,24 +4,23 @@ Build User Sequences for Sequential Models (SASRec, YoutubeDNN)
 
 Converts user interaction history into padded sequences for training.
 
+TIME-SPLIT (strict): Uses train.csv ONLY for sequences and item_map.
+This prevents leakage when Ranking uses SASRec embeddings as features:
+- Val/test samples must not appear in user history when computing sasrec_score.
+- Recall models (SASRec, YoutubeDNN) overwrite these with their own train-only output.
+
 Usage:
     python scripts/data/build_sequences.py
 
 Input:
-    - data/rec/train.csv, val.csv, test.csv
+    - data/rec/train.csv (val.csv, test.csv exist but are NOT used for sequences)
 
 Output:
-    - data/rec/user_sequences.pkl (Dict[user_id, List[item_id]])
-    - data/rec/item_map.pkl (Dict[isbn, item_id])
-
-Notes:
-    - Item IDs are 1-indexed (0 is reserved for padding)
-    - Sequences are truncated to max_len (default: 50)
-    - Test item is excluded from sequences (used for evaluation)
+    - data/rec/user_sequences.pkl (Dict[user_id, List[item_id]]) — train-only
+    - data/rec/item_map.pkl (Dict[isbn, item_id]) — train-only
 """
 
 import pandas as pd
-import numpy as np
 import pickle
 import logging
 from pathlib import Path
@@ -30,83 +29,43 @@ from tqdm import tqdm
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-def build_sequences(data_dir='data/rec', max_len=50):
+
+def build_sequences(data_dir="data/rec", max_len=50):
     """
-    Convert user interactions into sequences for SASRec
+    Build user sequences from train.csv only (strict time-split).
+    Val/test are excluded to avoid leakage in ranking features (sasrec_score).
     """
-    logger.info("Building user sequences...")
-
-    # Load all data to map ISBNs to Integers (SASRec needs int IDs)
-    train_df = pd.read_csv(f'{data_dir}/train.csv')
-    val_df = pd.read_csv(f'{data_dir}/val.csv')
-    test_df = pd.read_csv(f'{data_dir}/test.csv')
-
-    full_df = pd.concat([train_df, val_df, test_df])
-
-    # 1. Map ISBN to Index (1-based, 0 is padding)
-    items = full_df['isbn'].unique()
-    item_map = {isbn: i+1 for i, isbn in enumerate(items)}
-    num_items = len(items)
-
-    logger.info(f"Total items: {num_items}")
-
-    # Save map
-    with open(f'{data_dir}/item_map.pkl', 'wb') as f:
-        pickle.dump(item_map, f)
-
-    # 2. Group by User and Sort by Time
-    # Note: Our data split script ALREADY sorted by time.
-    # But let's be safe. We need original timestamps if possible.
-    # 'train.csv' doesn't have timestamp column? let me check split_rec_data.
-    # Ah, split_rec_data removed it. But rows are ordered.
-    # Actually, we can just group by user_id and assume rows are chronological
-    # IF we process train -> val -> test order.
-
-    # Let's reconstruct full history per user
-    logger.info("Grouping user history...")
-
-    # Optimization: processing via dictionary is faster than groupby on large df
-    user_history = {}  # user_id -> list of item_ids
-
-    def process_df(df):
-        for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing"):
-            u = row['user_id']
-            item = item_map[row['isbn']]
-            if u not in user_history:
-                user_history[u] = []
-            user_history[u].append(item)
-
-    # Process in chronological order: Train -> Val -> Test
-    # Wait, split_rec_data puts last item in test, 2nd last in val.
-    # So correct order is: Train rows + Val row + Test row.
-
-    # BUT, train.csv has multiple rows per user. They are already sorted by time in split logic.
-    process_df(train_df)
-    process_df(val_df)
-
-    # We leave Test item out of the input sequence!
-    # Test item is the target for evaluation.
-    # For training SASRec, we use (seq[:-1]) -> predict (seq[1:]).
-
-    # 3. Create Dataset
-    # Output:
-    #   train_seqs: Dict[user_id, list_of_ints]
-
-    # Pad/Truncate
-    final_seqs = {}
-
-    for u, history in user_history.items():
-        # Truncate to max_len
-        seq = history[-max_len:]
-        final_seqs[u] = seq
-
-    logger.info(f"Processed {len(final_seqs)} users.")
-
-    # Save sequences
-    with open(f'{data_dir}/user_sequences.pkl', 'wb') as f:
-        pickle.dump(final_seqs, f)
-
-    logger.info("Sequence data saved.")
+    logger.info("Building user sequences (train-only, time-split)...")
+    train_df = pd.read_csv(f"{data_dir}/train.csv")
+
+    # 1. Item map from train only (matches SASRec/YoutubeDNN)
+    items = train_df["isbn"].unique()
+    item_map = {isbn: i + 1 for i, isbn in enumerate(items)}
+    logger.info("  Items (train): %d", len(item_map))
+
+    # 2. User history from train only (no val/test)
+    user_history = {}
+    if "timestamp" in train_df.columns:
+        train_df = train_df.sort_values(["user_id", "timestamp"])
+    for _, row in tqdm(train_df.iterrows(), total=len(train_df), desc="  Processing"):
+        u = str(row["user_id"])
+        item = item_map.get(row["isbn"])
+        if item is None:
+            continue
+        if u not in user_history:
+            user_history[u] = []
+        user_history[u].append(item)
+
+    final_seqs = {u: hist[-max_len:] for u, hist in user_history.items()}
+    logger.info("  Users: %d", len(final_seqs))
+
+    data_dir = Path(data_dir)
+    data_dir.mkdir(parents=True, exist_ok=True)
+    with open(data_dir / "item_map.pkl", "wb") as f:
+        pickle.dump(item_map, f)
+    with open(data_dir / "user_sequences.pkl", "wb") as f:
+        pickle.dump(final_seqs, f)
+    logger.info("Sequence data saved (train-only).")
 
 if __name__ == "__main__":
     build_sequences()
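A quick way to sanity-check the two artifacts the script writes (paths and shapes taken from the docstring above; a sketch only):

```python
import pickle

with open("data/rec/user_sequences.pkl", "rb") as f:
    seqs = pickle.load(f)          # Dict[user_id, List[item_id]], train-only
with open("data/rec/item_map.pkl", "rb") as f:
    item_map = pickle.load(f)      # Dict[isbn, item_id], 1-indexed (0 = padding)

print(len(seqs), "users;", len(item_map), "items")
print("example sequence:", next(iter(seqs.values()))[:10])
```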
scripts/data/fetch_new_books.py ADDED
@@ -0,0 +1,322 @@
+#!/usr/bin/env python3
+"""
+Incremental Book Update Script.
+
+Fetches recently published books from Google Books API and adds them to the local database.
+Can be run manually or scheduled via cron for periodic updates.
+
+Usage:
+    python scripts/data/fetch_new_books.py [--categories CATEGORIES] [--year YEAR] [--max MAX]
+
+Examples:
+    # Fetch new books from current year (default behavior)
+    python scripts/data/fetch_new_books.py --categories "fiction" --max 50
+
+    # Fetch new books across multiple categories
+    python scripts/data/fetch_new_books.py --categories "fiction,mystery,science fiction"
+
+    # Explicitly specify year filter
+    python scripts/data/fetch_new_books.py --year 2026 --categories "thriller"
+
+    # Dry run (show what would be added without actually adding)
+    python scripts/data/fetch_new_books.py --dry-run --categories "thriller"
+"""
+import argparse
+import sys
+import time
+from pathlib import Path
+from datetime import datetime
+from typing import Optional
+
+# Add project root to path
+PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
+sys.path.insert(0, str(PROJECT_ROOT))
+
+from src.utils import setup_logger
+from src.core.web_search import search_new_books_by_category, search_google_books
+from src.core.metadata_store import metadata_store
+from src.recommender import BookRecommender
+
+logger = setup_logger(__name__)
+
+# Default categories to search
+DEFAULT_CATEGORIES = [
+    "fiction",
+    "mystery",
+    "thriller",
+    "science fiction",
+    "fantasy",
+    "romance",
+    "biography",
+    "history",
+    "self-help",
+    "business",
+]
+
+
+def fetch_trending_books(
+    categories: list[str],
+    year: Optional[int] = None,
+    max_per_category: int = 20,
+    dry_run: bool = False,
+) -> dict:
+    """
+    Fetch recently published books from Google Books for given categories.
+
+    Args:
+        categories: List of book categories to search
+        year: Filter by publication year (default: current year)
+        max_per_category: Max books to fetch per category
+        dry_run: If True, don't actually add books to database
+
+    Returns:
+        Dict with stats: {added: int, skipped: int, errors: int, books: list}
+    """
+    if year is None:
+        year = datetime.now().year
+
+    stats = {
+        "added": 0,
+        "skipped": 0,
+        "errors": 0,
+        "books": [],
+    }
+
+    recommender = None
+    if not dry_run:
+        recommender = BookRecommender()
+
+    for category in categories:
+        logger.info(f"Fetching books for category: {category} (year >= {year})")
+
+        try:
+            books = search_new_books_by_category(
+                category=category,
+                year=year,
+                max_results=max_per_category
+            )
+
+            logger.info(f"  Found {len(books)} books in '{category}'")
+
+            for book in books:
+                isbn = book.get("isbn13", "")
+                if not isbn:
+                    continue
+
+                # Check if already exists
+                if metadata_store.book_exists(isbn):
+                    stats["skipped"] += 1
+                    continue
+
+                if dry_run:
+                    logger.info(f"  [DRY RUN] Would add: {book.get('title', 'Unknown')} ({isbn})")
+                    stats["books"].append(book)
+                    stats["added"] += 1
+                else:
+                    result = recommender.add_new_book(
+                        isbn=isbn,
+                        title=book.get("title", ""),
+                        author=book.get("authors", "Unknown"),
+                        description=book.get("description", ""),
+                        category=book.get("simple_categories", category),
+                        thumbnail=book.get("thumbnail"),
+                        published_date=book.get("publishedDate", ""),
+                    )
+
+                    if result:
+                        stats["added"] += 1
+                        stats["books"].append(book)
+                        logger.info(f"  Added: {book.get('title', 'Unknown')} ({isbn})")
+                    else:
+                        stats["errors"] += 1
+
+                # Rate limiting: avoid hitting API limits
+                time.sleep(0.1)
+
+            # Pause between categories
+            time.sleep(0.5)
+
+        except Exception as e:
+            logger.error(f"Error fetching category '{category}': {e}")
+            stats["errors"] += 1
+
+    return stats
+
+
+def fetch_by_query(
+    queries: list[str],
+    max_per_query: int = 20,
+    dry_run: bool = False,
+) -> dict:
+    """
+    Fetch books by specific search queries (e.g., "AI books 2024", "new thriller novels").
+
+    Args:
+        queries: List of search queries
+        max_per_query: Max books per query
+        dry_run: If True, don't actually add books
+
+    Returns:
+        Stats dict
+    """
+    stats = {
+        "added": 0,
+        "skipped": 0,
+        "errors": 0,
+        "books": [],
+    }
+
+    recommender = None
+    if not dry_run:
+        recommender = BookRecommender()
+
+    for query in queries:
+        logger.info(f"Searching: {query}")
+
+        try:
+            books = search_google_books(query, max_results=max_per_query)
+            logger.info(f"  Found {len(books)} results")
+
+            for book in books:
+                isbn = book.get("isbn13", "")
+                if not isbn:
+                    continue
+
+                if metadata_store.book_exists(isbn):
+                    stats["skipped"] += 1
+                    continue
+
+                if dry_run:
+                    logger.info(f"  [DRY RUN] Would add: {book.get('title', 'Unknown')}")
+                    stats["books"].append(book)
+                    stats["added"] += 1
+                else:
+                    result = recommender.add_new_book(
+                        isbn=isbn,
+                        title=book.get("title", ""),
+                        author=book.get("authors", "Unknown"),
+                        description=book.get("description", ""),
+                        category=book.get("simple_categories", "General"),
+                        thumbnail=book.get("thumbnail"),
+                        published_date=book.get("publishedDate", ""),
+                    )
+
+                    if result:
+                        stats["added"] += 1
+                        stats["books"].append(book)
+                    else:
+                        stats["errors"] += 1
+
+                time.sleep(0.1)
+
+            time.sleep(0.5)
+
+        except Exception as e:
+            logger.error(f"Error with query '{query}': {e}")
+            stats["errors"] += 1
+
+    return stats
+
+
+def print_stats(stats: dict, dry_run: bool = False):
+    """Print summary statistics."""
+    prefix = "[DRY RUN] " if dry_run else ""
+    print(f"\n{prefix}=== Fetch Complete ===")
+    print(f"  Books added:   {stats['added']}")
+    print(f"  Books skipped: {stats['skipped']} (already in database)")
+    print(f"  Errors:        {stats['errors']}")
+
+    if stats["books"] and dry_run:
+        print(f"\nBooks that would be added:")
+        for book in stats["books"][:10]:
+            print(f"  - {book.get('title', 'Unknown')} by {book.get('authors', 'Unknown')}")
+        if len(stats["books"]) > 10:
+            print(f"  ... and {len(stats['books']) - 10} more")
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Fetch new books from Google Books API",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog=__doc__
+    )
+
+    parser.add_argument(
+        "--categories",
+        type=str,
+        default=None,
+        help="Comma-separated list of categories (default: all common categories)"
+    )
+    parser.add_argument(
+        "--queries",
+        type=str,
+        default=None,
+        help="Comma-separated list of custom search queries"
+    )
+    parser.add_argument(
+        "--year",
+        type=int,
+        default=None,
+        help="Filter by publication year (default: current year)"
+    )
+    parser.add_argument(
+        "--max",
+        type=int,
+        default=20,
+        help="Max books per category/query (default: 20)"
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        help="Show what would be added without actually adding"
+    )
+    parser.add_argument(
+        "--verbose",
+        "-v",
+        action="store_true",
+        help="Enable verbose logging"
+    )
+
+    args = parser.parse_args()
+
+    # Parse categories
+    if args.categories:
+        categories = [c.strip() for c in args.categories.split(",")]
+    else:
+        categories = DEFAULT_CATEGORIES
+
+    # Parse queries
+    queries = None
+    if args.queries:
+        queries = [q.strip() for q in args.queries.split(",")]
+
+    print(f"Book Fetch Configuration:")
+    print(f"  Categories:   {categories if not queries else 'N/A (using queries)'}")
+    print(f"  Queries:      {queries or 'N/A (using categories)'}")
+    print(f"  Year filter:  >= {args.year or datetime.now().year}")
+    print(f"  Max per item: {args.max}")
+    print(f"  Dry run:      {args.dry_run}")
+    print()
+
+    # Fetch books
+    if queries:
+        stats = fetch_by_query(
+            queries=queries,
+            max_per_query=args.max,
+            dry_run=args.dry_run,
+        )
+    else:
+        stats = fetch_trending_books(
+            categories=categories,
+            year=args.year,
+            max_per_category=args.max,
+            dry_run=args.dry_run,
+        )
+
+    print_stats(stats, args.dry_run)
+
+    return 0 if stats["errors"] == 0 else 1
+
+
+if __name__ == "__main__":
+    sys.exit(main())
scripts/data/validate_data.py CHANGED
@@ -144,9 +144,27 @@ def validate_rec():
         print(f"   User sequences: {len(seqs):,}")
         avg_len = np.mean([len(s) for s in seqs.values()])
         print(f"   Avg sequence length: {avg_len:.1f}")
+
+        # Time-split: no val items in sequences (prevents sasrec_score leakage)
+        if ITEM_MAP.exists():
+            with open(ITEM_MAP, "rb") as f:
+                item_map = pickle.load(f)
+            id_to_item = {v: k for k, v in item_map.items()}
+            leaked = 0
+            for _, row in val.iterrows():
+                uid, val_isbn = str(row["user_id"]), str(row["isbn"])
+                if uid not in seqs:
+                    continue
+                val_iid = item_map.get(val_isbn)
+                if val_iid is None:
+                    continue  # val item not in map (train-only) -> no leak possible
+                if val_iid in seqs[uid]:
+                    leaked += 1
+            check(leaked == 0, f"Time-split violation: {leaked} users have val items in sequence")
+            print("   ✅ Time-split OK (no val in sequences)")
     else:
         print("   ⚠️ User sequences not found (run build_sequences.py)")
-
+
     print("   ✅ Rec data validation passed")
     return True
scripts/model/evaluate_rag.py ADDED
@@ -0,0 +1,169 @@
+#!/usr/bin/env python3
+"""
+Evaluate RAG retrieval on a Golden Test Set.
+
+Replaces "curated examples" with quantitative metrics: Accuracy@K, Recall@K, MRR@K.
+Uses human-annotated Query-Book pairs for data-driven evaluation.
+
+Usage:
+    python scripts/model/evaluate_rag.py
+    python scripts/model/evaluate_rag.py --golden data/rag_golden.csv --top_k 10
+
+Golden set format (CSV): query, isbn, relevance
+    - query: user search string
+    - isbn: expected relevant book ISBN
+    - relevance: 1 = relevant
+    - Multiple rows per query = multiple relevant books
+"""
+
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
+
+import pandas as pd
+import logging
+from collections import defaultdict
+
+from src.recommender import BookRecommender
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
+
+def load_golden(path: Path) -> dict[str, set[str]]:
+    """Load golden set: {query -> set of relevant isbns}."""
+    df = pd.read_csv(path, comment="#")
+    if "relevance" in df.columns:
+        df = df[df["relevance"] == 1]  # keep only relevant pairs
+    golden = defaultdict(set)
+    for _, row in df.iterrows():
+        q = str(row["query"]).strip()
+        isbn = str(row["isbn"]).strip().removesuffix(".0")
+        if q and isbn:
+            golden[q].add(isbn)
+    return dict(golden)
+
+
+def evaluate_rag(
+    golden_path: Path | str = "data/rag_golden.csv",
+    top_k: int = 10,
+    use_title_match: bool = True,
+) -> dict:
+    """
+    Run RAG retrieval on golden set and compute metrics.
+
+    Returns: dict with accuracy_at_k, recall_at_k, mrr_at_k, n_queries
+    """
+    golden_path = Path(golden_path)
+    if not golden_path.exists():
+        # Fallback to example
+        alt = Path("data/rag_golden.example.csv")
+        if alt.exists():
+            logger.warning("Golden set not found at %s, using %s", golden_path, alt)
+            golden_path = alt
+        else:
+            raise FileNotFoundError(
+                f"Golden set not found. Create {golden_path} with columns: query,isbn,relevance. "
+                "See data/rag_golden.example.csv for format."
+            )
+
+    golden = load_golden(golden_path)
+    if not golden:
+        raise ValueError("Golden set is empty")
+
+    logger.info("Evaluating RAG on %d queries from %s", len(golden), golden_path)
+
+    recommender = BookRecommender()
+    isbn_to_title = {}
+    if use_title_match:
+        try:
+            bp = Path("data/books_processed.csv")
+            if not bp.exists():
+                bp = Path(__file__).resolve().parent.parent.parent / "data" / "books_processed.csv"
+            books = pd.read_csv(bp, usecols=["isbn13", "title"])
+            books["isbn13"] = books["isbn13"].astype(str).str.replace(r"\.0$", "", regex=True)
+            isbn_to_title = books.set_index("isbn13")["title"].to_dict()
+        except Exception as e:
+            logger.warning("Could not load title map: %s", e)
+            use_title_match = False
+
+    hits_acc = 0
+    recall_sum = 0.0
+    mrr_sum = 0.0
+
+    for query, relevant_isbns in golden.items():
+        try:
+            recs = recommender.get_recommendations(query, top_k=top_k * 2)
+            rec_isbns = [r.get("isbn") or r.get("isbn13") for r in recs if r]
+            rec_isbns = [str(x).removesuffix(".0") for x in rec_isbns if pd.notna(x)]
+            rec_top = rec_isbns[:top_k]
+
+            # Match: exact or title
+            def _match(target: str, cand_list: list) -> int:
+                for i, c in enumerate(cand_list):
+                    if str(c).strip() == str(target).strip():
+                        return i
+                    if use_title_match:
+                        t_title = isbn_to_title.get(str(target), "").lower().strip()
+                        c_title = isbn_to_title.get(str(c), "").lower().strip()
+                        if t_title and c_title and t_title == c_title:
+                            return i
+                return -1
+
+            # Accuracy@K: at least one relevant in top-K
+            found_any = False
+            first_rank = top_k + 1
+            count_in_top = 0
+
+            for rel in relevant_isbns:
+                rk = _match(rel, rec_top)
+                if rk >= 0:
+                    found_any = True
+                    count_in_top += 1
+                    first_rank = min(first_rank, rk + 1)
+
+            if found_any:
+                hits_acc += 1
+            recall_sum += count_in_top / len(relevant_isbns) if relevant_isbns else 0
+            if first_rank <= top_k:
+                mrr_sum += 1.0 / first_rank
+
+        except Exception as e:
+            logger.warning("Query %r failed: %s", query[:50], e)
+
+    n = len(golden)
+    return {
+        "accuracy_at_k": hits_acc / n,
+        "recall_at_k": recall_sum / n,
+        "mrr_at_k": mrr_sum / n,
+        "n_queries": n,
+        "top_k": top_k,
+    }
+
+
+def main():
+    import argparse
+    parser = argparse.ArgumentParser(description="Evaluate RAG on Golden Test Set")
+    parser.add_argument("--golden", default="data/rag_golden.csv", help="Path to golden CSV")
+    parser.add_argument("--top_k", type=int, default=10)
+    parser.add_argument("--no-title-match", action="store_true", help="Disable relaxed title matching")
+    args = parser.parse_args()
+
+    m = evaluate_rag(
+        golden_path=args.golden,
+        top_k=args.top_k,
+        use_title_match=not args.no_title_match,
+    )
+
+    print("\n" + "=" * 50)
+    print(" RAG Golden Test Set Evaluation")
+    print("=" * 50)
+    print(f" Queries:  {m['n_queries']}")
+    print(f" Top-K:    {m['top_k']}")
+    print(f" Accuracy@{m['top_k']}: {m['accuracy_at_k']:.4f}  (any relevant in top-K)")
+    print(f" Recall@{m['top_k']}:   {m['recall_at_k']:.4f}  (fraction of relevant in top-K)")
+    print(f" MRR@{m['top_k']}:      {m['mrr_at_k']:.4f}  (mean reciprocal rank)")
+    print("=" * 50)
+
+
+if __name__ == "__main__":
+    main()
scripts/model/train_din_ranker.py ADDED
@@ -0,0 +1,264 @@
+#!/usr/bin/env python3
+"""
+Train DIN (Deep Interest Network) ranker.
+
+Uses attention over the user behavior sequence w.r.t. the target item.
+Reuses SASRec item embeddings as initialization when available.
+
+Usage:
+    python scripts/model/train_din_ranker.py
+    python scripts/model/train_din_ranker.py --max_samples 10000 --epochs 10
+
+Input:
+    - data/rec/val.csv, train.csv
+    - data/rec/user_sequences.pkl, item_map.pkl (from SASRec/YoutubeDNN)
+    - data/model/rec/sasrec_model.pth (optional, for init)
+
+Output:
+    - data/model/ranking/din_ranker.pt
+"""
+
+import sys
+import os
+
+sys.path.append(os.getcwd())
+
+import pickle
+import logging
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader
+from tqdm import tqdm
+
+from src.ranking.din import DIN
+from src.recall.fusion import RecallFusion
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+
+def build_din_data(
+    data_dir: str = "data/rec",
+    model_dir: str = "data/model/recall",
+    neg_ratio: int = 4,
+    max_samples: int = 20000,
+) -> tuple[pd.DataFrame, dict, dict]:
+    """
+    Build (user_id, isbn, label) samples with hard negatives.
+    Returns (df, user_sequences, item_map).
+    """
+    logger.info("Building DIN training data...")
+    val_df = pd.read_csv(f"{data_dir}/val.csv")
+    all_items = pd.read_csv(f"{data_dir}/train.csv")["isbn"].astype(str).unique()
+
+    if len(val_df) > max_samples:
+        val_df = val_df.sample(n=max_samples, random_state=42).reset_index(drop=True)
+
+    fusion = RecallFusion(data_dir, model_dir)
+    fusion.load_models()
+
+    with open(f"{data_dir}/user_sequences.pkl", "rb") as f:
+        user_sequences = pickle.load(f)
+    with open(f"{data_dir}/item_map.pkl", "rb") as f:
+        item_map = pickle.load(f)
+
+    rows = []
+    for _, row in tqdm(val_df.iterrows(), total=len(val_df), desc="Mining samples"):
+        user_id = str(row["user_id"])
+        pos_isbn = str(row["isbn"])
+
+        user_rows = [{"user_id": user_id, "isbn": pos_isbn, "label": 1}]
+
+        try:
+            recall_items = fusion.get_recall_items(user_id, k=50)
+            hard_negs = [item for item, _ in recall_items if item != pos_isbn][:neg_ratio]
+        except Exception:
+            hard_negs = []
+
+        for neg_isbn in hard_negs:
+            user_rows.append({"user_id": user_id, "isbn": str(neg_isbn), "label": 0})
+
+        n_remaining = neg_ratio - len(hard_negs)
+        if n_remaining > 0:
+            random_negs = np.random.choice(all_items, size=n_remaining, replace=False)
+            for neg_isbn in random_negs:
+                user_rows.append({"user_id": user_id, "isbn": str(neg_isbn), "label": 0})
+
+        rows.extend(user_rows)
+
+    df = pd.DataFrame(rows)
+    logger.info(f"Built {len(df)} samples")
+    return df, user_sequences, item_map
+
+
+class DINDataset(Dataset):
+    """Dataset for DIN: (user_hist, target_item_id, label) and optional aux features."""
+
+    def __init__(
+        self,
+        df: pd.DataFrame,
+        user_sequences: dict,
+        item_map: dict,
+        max_hist_len: int = 50,
+        aux_df: pd.DataFrame | None = None,
+        aux_cols: list[str] | None = None,
+    ):
+        self.samples = []
+        self.max_hist_len = max_hist_len
+        self.aux_df = aux_df
+        self.aux_cols = aux_cols or []
+        for idx, (_, row) in enumerate(df.iterrows()):
+            user_id = str(row["user_id"])
+            isbn = str(row["isbn"])
+            label = int(row["label"])
+            target_id = item_map.get(isbn, 0)
+            if target_id == 0:
+                continue
+
+            hist = user_sequences.get(user_id, [])
+            if hist and isinstance(hist[0], str):
+                hist = [item_map.get(h, 0) for h in hist if item_map.get(h, 0) > 0]
+            hist = [x for x in hist if x != target_id][-max_hist_len:]
+
+            self.samples.append((hist, target_id, label, idx))
+
+    def __len__(self) -> int:
+        return len(self.samples)
+
+    def __getitem__(self, idx: int):
+        hist, target_id, label, df_idx = self.samples[idx]
+        max_len = self.max_hist_len  # pad to the configured history length (was hardcoded 50)
+        padded = np.zeros(max_len, dtype=np.int64)
+        padded[: len(hist)] = hist
+        out = (
+            torch.LongTensor(padded),
+            torch.LongTensor([target_id]).squeeze(0),
+            torch.FloatTensor([label]).squeeze(0),
+        )
+        if self.aux_df is not None and self.aux_cols:
+            aux_row = self.aux_df.iloc[df_idx][self.aux_cols].values.astype(np.float32)
+            out = out + (torch.FloatTensor(aux_row),)
+        return out
+
+
+def train_din(
+    data_dir: str = "data/rec",
+    model_dir: str = "data/model",
+    recall_dir: str = "data/model/recall",
+    max_samples: int = 20000,
+    max_hist_len: int = 50,
+    embed_dim: int = 64,
+    epochs: int = 10,
+    batch_size: int = 256,
+    lr: float = 1e-3,
+    use_aux: bool = False,
+) -> None:
+    rank_dir = Path(model_dir) / "ranking"
+    rank_dir.mkdir(parents=True, exist_ok=True)
+
+    df, user_sequences, item_map = build_din_data(
+        data_dir, recall_dir, neg_ratio=4, max_samples=max_samples
+    )
+    num_items = len(item_map)
+
+    aux_df = None
+    aux_cols: list[str] = []
+    if use_aux:
+        from src.ranking.features import FeatureEngineer
+        fe = FeatureEngineer(data_dir, recall_dir)
+        fe.load_base_data()
+        logger.info("Generating aux features for DIN...")
+        aux_df = fe.create_dateset(df)
+        aux_cols = [c for c in aux_df.columns if c not in ("label", "user_id", "isbn")]
+        logger.info("Aux features: %s", aux_cols)
+
+    num_aux = len(aux_cols)
+    dataset = DINDataset(
+        df, user_sequences, item_map,
+        max_hist_len=max_hist_len,
+        aux_df=aux_df,
+        aux_cols=aux_cols if aux_cols else None,
+    )
+    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
+
+    pretrained_emb = None
+    sasrec_path = Path(model_dir) / "rec" / "sasrec_model.pth"
+    if sasrec_path.exists():
+        try:
+            state = torch.load(sasrec_path, map_location="cpu", weights_only=False)
+            emb = state.get("item_emb.weight")
+            if emb is not None:
+                pretrained_emb = emb.numpy()
+                logger.info("Loaded SASRec item_emb for DIN init: %s", pretrained_emb.shape)
+        except Exception as e:
+            logger.warning("Could not load SASRec init: %s", e)
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    if torch.backends.mps.is_available():
+        device = torch.device("mps")
+
+    model = DIN(
+        num_items=num_items,
+        embed_dim=embed_dim,
+        max_hist_len=max_hist_len,
+        num_aux=num_aux,
+        pretrained_item_emb=pretrained_emb,
+    ).to(device)
+    opt = torch.optim.Adam(model.parameters(), lr=lr)
+
+    for ep in range(epochs):
+        model.train()
+        total_loss = 0.0
+        n_batches = 0
+        for batch in tqdm(loader, desc=f"Epoch {ep+1}/{epochs}"):
+            hist = batch[0].to(device)
+            target = batch[1].to(device)
+            label = batch[2].to(device)
+            aux = batch[3].to(device) if len(batch) > 3 else None
+            opt.zero_grad()
+            logits = model(hist, target, aux)
+            loss = F.binary_cross_entropy_with_logits(logits, label)
+            loss.backward()
+            opt.step()
+            total_loss += loss.item()
+            n_batches += 1
+        avg = total_loss / max(n_batches, 1)
+        logger.info(f"Epoch {ep+1} loss: {avg:.4f}")
+
+    ckpt = {
+        "model": model,
+        "item_map": item_map,
+        "max_hist_len": max_hist_len,
+        "aux_feature_names": aux_cols,
+    }
+    out_path = rank_dir / "din_ranker.pt"
+    torch.save(ckpt, out_path)
+    logger.info("DIN ranker saved to %s", out_path)
+
+    with open(Path(data_dir) / "user_sequences.pkl", "wb") as f:
+        pickle.dump(user_sequences, f)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Train DIN ranker")
+    parser.add_argument("--max_samples", type=int, default=20000)
+    parser.add_argument("--epochs", type=int, default=10)
+    parser.add_argument("--batch_size", type=int, default=256)
+    parser.add_argument("--aux", action="store_true", help="Use aux features from FeatureEngineer")
+    args = parser.parse_args()
+
+    train_din(
+        max_samples=args.max_samples,
+        epochs=args.epochs,
+        batch_size=args.batch_size,
+        use_aux=args.aux,
+    )
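The `DIN` module imported from `src/ranking/din.py` is not part of this diff. For orientation, a minimal sketch of the DIN-style target-aware attention such a module typically implements (all layer sizes illustrative, not the repo's actual model):

```python
import torch
import torch.nn as nn

class DINAttentionSketch(nn.Module):
    """Target-aware attention pooling over user history (illustrative only)."""

    def __init__(self, num_items: int, embed_dim: int = 64):
        super().__init__()
        self.item_emb = nn.Embedding(num_items + 1, embed_dim, padding_idx=0)
        # Attention MLP scores each history item against the target candidate
        self.att = nn.Sequential(nn.Linear(embed_dim * 4, 64), nn.ReLU(), nn.Linear(64, 1))
        self.out = nn.Linear(embed_dim * 2, 1)

    def forward(self, hist: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        h = self.item_emb(hist)                         # (B, L, D) history embeddings
        t = self.item_emb(target).unsqueeze(1)          # (B, 1, D) target embedding
        te = t.expand_as(h)
        feats = torch.cat([h, te, h - te, h * te], -1)  # DIN-style interaction features
        scores = self.att(feats).squeeze(-1)            # (B, L) attention logits
        scores = scores.masked_fill(hist == 0, -1e9)    # ignore padding positions
        w = torch.softmax(scores, dim=-1).unsqueeze(-1)
        interest = (w * h).sum(1)                       # weighted sum of history
        return self.out(torch.cat([interest, t.squeeze(1)], -1)).squeeze(-1)  # logits
```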
scripts/model/train_intent_router.py ADDED
@@ -0,0 +1,156 @@
+#!/usr/bin/env python3
+"""
+Train model-based intent classifier for Query Router.
+
+Replaces rule-based heuristics with TF-IDF + LogisticRegression (or FastText/DistilBERT).
+Uses synthetic seed data; extend with real labeled queries via --data CSV.
+
+Usage:
+    python scripts/model/train_intent_router.py
+    python scripts/model/train_intent_router.py --data data/intent_labels.csv
+    python scripts/model/train_intent_router.py --backend fasttext
+    python scripts/model/train_intent_router.py --backend distilbert
+
+Output:
+    data/model/intent_classifier.pkl (or .bin for fasttext)
+"""
+
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
+
+import joblib
+import logging
+
+import pandas as pd
+
+from src.core.intent_classifier import train_classifier, INTENTS
+
+logging.basicConfig(level=logging.INFO, format="%(message)s")
+logger = logging.getLogger(__name__)
+
+# Synthetic training data: (query, intent)
+# Extend with real user queries for better generalization
+SEED_DATA = [
+    # small_to_big: detail-oriented, plot/review focused
+    ("book with twist ending", "small_to_big"),
+    ("unreliable narrator", "small_to_big"),
+    ("spoiler about the ending", "small_to_big"),
+    ("what did readers think", "small_to_big"),
+    ("opinion on the book", "small_to_big"),
+    ("hidden details in the story", "small_to_big"),
+    ("did anyone cry reading this", "small_to_big"),
+    ("review of the book", "small_to_big"),
+    ("plot twist reveal", "small_to_big"),
+    ("unreliable narrator twist", "small_to_big"),
+    ("readers who loved the ending", "small_to_big"),
+    ("spoiler what happens at the end", "small_to_big"),
+    # fast: short keyword queries
+    ("AI book", "fast"),
+    ("Python", "fast"),
+    ("romance", "fast"),
+    ("machine learning", "fast"),
+    ("science fiction", "fast"),
+    ("best AI book", "fast"),
+    ("Python programming", "fast"),
+    ("self help", "fast"),
+    ("business", "fast"),
+    ("fiction", "fast"),
+    ("thriller", "fast"),
+    ("mystery novel", "fast"),
+    ("finance", "fast"),
+    ("history", "fast"),
+    ("psychology", "fast"),
+    ("data science", "fast"),
+    ("cooking", "fast"),
+    ("music", "fast"),
+    ("art", "fast"),
+    ("philosophy", "fast"),
+    # deep: natural language, complex queries
+    ("What are the best books about artificial intelligence for beginners", "deep"),
+    ("I'm looking for something similar to Harry Potter", "deep"),
+    ("Books that help you understand machine learning", "deep"),
+    ("Recommend me a book like Sapiens but about technology", "deep"),
+    ("I want to learn about psychology and human behavior", "deep"),
+    ("What should I read if I liked 1984", "deep"),
+    ("Looking for books on startup founding and entrepreneurship", "deep"),
+    ("Can you suggest books about climate change and sustainability", "deep"),
+    ("I need a book that explains quantum physics simply", "deep"),
+    ("Books for someone who wants to improve their writing skills", "deep"),
+    ("What are some good fiction books set in Japan", "deep"),
+    ("Recommendations for someone getting into philosophy", "deep"),
+    ("Books that discuss the future of work and automation", "deep"),
+    ("I'm interested in biographies of scientists", "deep"),
+    ("Something light and funny for a long flight", "deep"),
+    ("Books about the history of mathematics", "deep"),
+    ("Recommend me novels with strong female protagonists", "deep"),
+    ("What to read to understand economics", "deep"),
+    ("Books on meditation and mindfulness", "deep"),
+]
+
+
+def load_training_data(data_path: Path | None) -> tuple[list[str], list[str]]:
+    """Load (queries, labels) from SEED_DATA + optional CSV."""
+    queries = [q for q, _ in SEED_DATA]
+    labels = [l for _, l in SEED_DATA]
+
+    if data_path and data_path.exists():
+        df = pd.read_csv(data_path)
+        q_col = "query" if "query" in df.columns else df.columns[0]
+        l_col = "intent" if "intent" in df.columns else df.columns[1]
+        extra_q = df[q_col].astype(str).tolist()
+        extra_l = df[l_col].astype(str).tolist()
+        queries.extend(extra_q)
+        labels.extend(extra_l)
+        logger.info("Loaded %d extra samples from %s", len(extra_q), data_path)
+
+    return queries, labels
+
+
+def main():
+    import argparse
+    parser = argparse.ArgumentParser(description="Train intent classifier")
+    parser.add_argument("--data", type=Path, default=None, help="CSV with query,intent columns")
+    parser.add_argument("--backend", choices=["tfidf", "fasttext", "distilbert"], default="tfidf")
+    args = parser.parse_args()
+
+    project_root = Path(__file__).resolve().parent.parent.parent
+    out_dir = project_root / "data" / "model"
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    queries, labels = load_training_data(args.data)
+
+    logger.info("Training intent classifier (%s) on %d samples...", args.backend, len(queries))
+    result = train_classifier(queries, labels, backend=args.backend)
+
+    if args.backend == "fasttext":
+        out_path = out_dir / "intent_classifier.bin"
+        result.save_model(str(out_path))
+    else:
+        out_path = out_dir / "intent_classifier.pkl"
+        if args.backend == "distilbert":
+            joblib.dump(result, out_path)  # dict with pipeline, backend, etc.
+        else:
+            joblib.dump({"pipeline": result, "backend": "tfidf"}, out_path)
+
+    logger.info("Saved to %s", out_path)
+
+    # Quick sanity check
+    for intent in INTENTS:
+        sample = next((q for q, l in zip(queries, labels) if l == intent), None)
+        if sample:
+            if args.backend == "fasttext":
+                pred = result.predict(sample)[0][0].replace("__label__", "")
+            elif args.backend == "distilbert":
+                from transformers import pipeline
+                pipe = pipeline("zero-shot-classification", model="distilbert-base-uncased", device=-1)
+                pred = pipe(sample, INTENTS, multi_label=False)["labels"][0]
+            else:
+                pred = result.predict([sample])[0]
+            ok = "✓" if pred == intent else "✗"
+            logger.info("  %s %s: %r -> %s", ok, intent, sample[:40], pred)
+
+
+if __name__ == "__main__":
+    main()
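`train_classifier` lives in `src/core/intent_classifier.py`, which this commit does not show. For the default `tfidf` backend it is presumably close to a TF-IDF + LogisticRegression pipeline; a hedged sketch:

```python
# Hedged sketch of a tfidf backend; the real train_classifier may differ.
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

def train_tfidf_classifier(queries: list[str], labels: list[str]) -> Pipeline:
    pipe = Pipeline([
        ("tfidf", TfidfVectorizer(ngram_range=(1, 2), min_df=1)),  # word + bigram features
        ("clf", LogisticRegression(max_iter=1000)),
    ])
    pipe.fit(queries, labels)
    return pipe

# Usage mirrors the script's sanity check:
# pred = train_tfidf_classifier(queries, labels).predict(["AI book"])[0]
```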
scripts/model/train_ranker.py CHANGED
@@ -15,18 +15,15 @@ Input:
     - data/rec/train.csv (for fallback random negatives)
     - data/model/recall/*.pkl (recall models for hard negative mining)
 
-Output (Standard):
-    - data/model/ranking/lgbm_ranker.txt
-
-Output (Stacking):
-    - data/model/ranking/lgbm_ranker.txt (full retrained LGB)
-    - data/model/ranking/xgb_ranker.json (full retrained XGB)
-    - data/model/ranking/stacking_meta.pkl (LogisticRegression meta-model)
+TIME-SPLIT (no leakage):
+    - Recall models (SASRec, etc.) are trained on train.csv only.
+    - Ranking uses val.csv for labels; recall for hard negatives.
+    - sasrec_score and user_seq_emb come from train-only SASRec.
+    - Pipeline order: split -> build_sequences(train-only) -> recall(train) -> ranker(val).
 
 Negative Sampling Strategy:
     - Hard negatives: items from recall results that are NOT the positive
     - Random negatives: fill remaining slots if recall returns too few
-    - This teaches the ranker to distinguish between "close but wrong" vs "right"
 """
 
 import sys
scripts/run_pipeline.py CHANGED
@@ -44,6 +44,7 @@ class Pipeline:
         skip_models: bool = False,
         skip_index: bool = False,
         stacking: bool = False,
+        train_din: bool = False,
     ):
         self.project_root = Path(project_root)
         self.data_dir = self.project_root / "data"
@@ -53,6 +54,7 @@ class Pipeline:
         self.skip_models = skip_models
         self.skip_index = skip_index
         self.stacking = stacking
+        self.train_din = train_din
 
     def _run_step(self, name: str, fn, *args, **kwargs):
         """Run a step with timing log."""
@@ -153,8 +155,19 @@ class Pipeline:
         from scripts.model.train_ranker import train_ranker, train_stacking
         self._run_step("Train Ranker", train_stacking if self.stacking else train_ranker)
 
+        from scripts.model.train_intent_router import main as train_intent
+        self._run_step("Train intent classifier", train_intent)
+
+        if getattr(self, "train_din", False):
+            from scripts.model.train_din_ranker import train_din
+            self._run_step("Train DIN ranker", lambda: train_din(
+                data_dir=str(self.rec_dir),
+                model_dir=str(self.model_dir),
+                recall_dir=str(self.model_dir / "recall"),
+            ))
+
     def run_evaluation(self) -> None:
-        """Stage 5: Validation."""
+        """Stage 5: Validation + RAG Golden Test Set (if exists)."""
         def _validate():
             from scripts.data.validate_data import (
                 validate_raw, validate_processed, validate_rec,
@@ -168,6 +181,18 @@ class Pipeline:
 
         self._run_step("Validate pipeline", _validate)
 
+        # RAG Golden Test Set evaluation (optional)
+        golden = self.rec_dir.parent / "rag_golden.csv"
+        if not golden.exists():
+            golden = self.rec_dir.parent / "rag_golden.example.csv"
+        if golden.exists():
+            def _run_rag_eval():
+                from scripts.model.evaluate_rag import evaluate_rag
+                m = evaluate_rag(str(golden))
+                logger.info("RAG Accuracy@%d: %.4f Recall@%d: %.4f MRR@%d: %.4f",
+                            m["top_k"], m["accuracy_at_k"], m["top_k"], m["recall_at_k"], m["top_k"], m["mrr_at_k"])
+            self._run_step("RAG Golden Test Set", _run_rag_eval)
+
     def run(self, stage: str = "all") -> None:
         """Execute full pipeline: Data Cleaning -> Training -> Evaluation."""
         logger.info("=" * 60)
@@ -201,6 +226,7 @@ def main():
     parser.add_argument("--validate-only", action="store_true", help="Only run validation")
     parser.add_argument("--device", default=None, help="Device for ML (cpu/cuda/mps)")
     parser.add_argument("--stacking", action="store_true", help="Enable stacking ranker")
+    parser.add_argument("--din", action="store_true", help="Train DIN ranker (deep model)")
    args = parser.parse_args()
 
     if args.validate_only:
@@ -213,6 +239,7 @@ def main():
         skip_models=args.skip_models,
         skip_index=args.skip_index,
         stacking=args.stacking,
+        train_din=args.din,
     )
     pipeline.run(stage=args.stage)
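
The new flags map one-to-one onto `Pipeline` constructor arguments. A minimal sketch of programmatic use, assuming the constructor accepts `project_root` as shown above:

```python
from scripts.run_pipeline import Pipeline

# Equivalent to: python scripts/run_pipeline.py --stacking --din
pipeline = Pipeline(
    project_root=".",
    stacking=True,    # stacking ensemble ranker instead of plain LGBMRanker
    train_din=True,   # additionally train the optional DIN deep ranker
)
pipeline.run(stage="all")
```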
src/core/freshness_monitor.py ADDED
@@ -0,0 +1,231 @@
+"""
+Data Freshness Monitor.
+
+Provides insights into the freshness of the local book database:
+- Distribution of books by publication year
+- Detection of data staleness
+- Recommendations for when to trigger updates
+
+This module helps the system decide when to rely on local data vs.
+triggering external API fallbacks.
+"""
+from datetime import datetime
+from typing import Optional
+
+from src.core.metadata_store import metadata_store
+from src.utils import setup_logger
+
+logger = setup_logger(__name__)
+
+
+class FreshnessMonitor:
+    """
+    Monitor data freshness and provide staleness detection.
+
+    Usage:
+        monitor = FreshnessMonitor()
+        stats = monitor.get_data_stats()
+        if monitor.is_stale_for_query("latest 2025 books"):
+            ...  # Trigger web search fallback
+    """
+
+    # Years considered "recent" for freshness calculations
+    RECENT_YEARS_THRESHOLD = 2
+
+    def __init__(self):
+        self._cache = {}
+        self._cache_timestamp = None
+        self._cache_ttl_seconds = 300  # 5 minutes
+
+    def _is_cache_valid(self) -> bool:
+        """Check if cached stats are still valid."""
+        if not self._cache or not self._cache_timestamp:
+            return False
+        age = (datetime.now() - self._cache_timestamp).total_seconds()
+        return age < self._cache_ttl_seconds
+
+    def get_data_stats(self, force_refresh: bool = False) -> dict:
+        """
+        Get comprehensive statistics about data freshness.
+
+        Returns:
+            Dict with:
+            - total_books: Total number of books in database
+            - newest_year: Year of most recently published book
+            - oldest_year: Year of oldest book
+            - books_by_year: Dict mapping year -> count
+            - recent_books_count: Books published in last N years
+            - data_cutoff_year: Effective "knowledge cutoff" year
+            - freshness_score: 0-100 score indicating data freshness
+        """
+        if not force_refresh and self._is_cache_valid():
+            return self._cache
+
+        stats = {
+            "total_books": 0,
+            "newest_year": None,
+            "oldest_year": None,
+            "books_by_year": {},
+            "recent_books_count": 0,
+            "data_cutoff_year": None,
+            "freshness_score": 0,
+            "last_checked": datetime.now().isoformat(),
+        }
+
+        try:
+            stats["total_books"] = metadata_store.get_book_count()
+            stats["books_by_year"] = metadata_store.get_books_by_year_distribution()
+
+            if stats["books_by_year"]:
+                years = sorted(stats["books_by_year"].keys())
+                stats["newest_year"] = max(years)
+                stats["oldest_year"] = min(years)
+                stats["data_cutoff_year"] = stats["newest_year"]
+
+            # Count recent books (last N years)
+            current_year = datetime.now().year
+            recent_threshold = current_year - self.RECENT_YEARS_THRESHOLD
+            stats["recent_books_count"] = sum(
+                count for year, count in stats["books_by_year"].items()
+                if year >= recent_threshold
+            )
+
+            # Calculate freshness score (0-100), based on the newest year
+            # relative to the current year: each year behind costs 25 points.
+            years_behind = current_year - (stats["newest_year"] or current_year)
+            stats["freshness_score"] = max(0, 100 - (years_behind * 25))
+
+            self._cache = stats
+            self._cache_timestamp = datetime.now()
+
+        except Exception as e:
+            logger.error(f"FreshnessMonitor.get_data_stats failed: {e}")
+
+        return stats
+
+    def is_stale(self, target_year: Optional[int] = None) -> bool:
+        """
+        Check if local data is too old for a given target year.
+
+        Args:
+            target_year: Year the user is asking about (default: current year)
+
+        Returns:
+            True if data is stale and web fallback should be triggered
+        """
+        if target_year is None:
+            target_year = datetime.now().year
+
+        stats = self.get_data_stats()
+        newest_year = stats.get("newest_year")
+
+        if newest_year is None:
+            return True  # No data at all
+
+        # Stale if target year is newer than our newest data
+        return target_year > newest_year
+
+    def is_stale_for_query(self, query: str) -> bool:
+        """
+        Analyze a query and determine if data is stale for it.
+
+        Args:
+            query: User's search query
+
+        Returns:
+            True if web fallback should be triggered
+        """
+        from src.core.web_search import extract_year_from_query
+
+        target_year = extract_year_from_query(query)
+
+        if target_year is None:
+            # No year requirement - check freshness score
+            stats = self.get_data_stats()
+            # Trigger fallback if data is more than 2 years old (score < 50)
+            return stats.get("freshness_score", 100) < 50
+
+        return self.is_stale(target_year)
+
+    def get_coverage_for_year(self, year: int) -> dict:
+        """
+        Get coverage statistics for a specific year.
+
+        Returns:
+            Dict with: count, percentage of total, is_well_covered
+        """
+        stats = self.get_data_stats()
+        year_count = stats["books_by_year"].get(year, 0)
+        total = stats["total_books"] or 1
+
+        return {
+            "year": year,
+            "count": year_count,
+            "percentage": round(year_count / total * 100, 2),
+            "is_well_covered": year_count >= 100,  # Arbitrary threshold
+        }
+
+    def recommend_update_categories(self) -> list[str]:
+        """
+        Recommend categories that should be updated.
+
+        Returns:
+            List of category names that need fresh data
+        """
+        # This would require category-level year tracking.
+        # For now, return common categories that benefit from freshness.
+        return [
+            "fiction",
+            "thriller",
+            "science fiction",
+            "fantasy",
+            "mystery",
+            "self-help",
+            "business",
+        ]
+
+    def get_summary(self) -> str:
+        """
+        Get a human-readable summary of data freshness.
+
+        Returns:
+            Formatted string describing data freshness status
+        """
+        stats = self.get_data_stats()
+
+        lines = [
+            "Data Freshness Report",
+            "=" * 40,
+            f"Total books: {stats['total_books']:,}",
+            f"Newest book year: {stats['newest_year'] or 'Unknown'}",
+            f"Data cutoff: {stats['data_cutoff_year'] or 'Unknown'}",
+            f"Recent books (last {self.RECENT_YEARS_THRESHOLD} years): {stats['recent_books_count']:,}",
+            f"Freshness score: {stats['freshness_score']}/100",
+        ]
+
+        current_year = datetime.now().year
+        if stats["newest_year"] and stats["newest_year"] < current_year:
+            years_behind = current_year - stats["newest_year"]
+            lines.append("")
+            lines.append(f"WARNING: Data is {years_behind} year(s) behind current year.")
+            lines.append(f"Consider running: python scripts/data/fetch_new_books.py --year {current_year}")
+
+        return "\n".join(lines)
+
+
+# Global instance
+freshness_monitor = FreshnessMonitor()
+
+
+# Convenience function for quick checks
+def is_data_fresh_enough(query: str) -> bool:
+    """
+    Quick check if local data is fresh enough for a query.
+
+    Args:
+        query: User's search query
+
+    Returns:
+        True if local data is sufficient, False if web fallback recommended
+    """
+    return not freshness_monitor.is_stale_for_query(query)
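
A short usage sketch of the monitor above; the module-level singleton is what other components import:

```python
from src.core.freshness_monitor import freshness_monitor, is_data_fresh_enough

stats = freshness_monitor.get_data_stats()
print(freshness_monitor.get_summary())

# Route-time check: a 2025 query against a 2023-cutoff catalog is stale,
# so the caller should enable the web-search fallback.
if not is_data_fresh_enough("latest 2025 books"):
    pass  # trigger web search fallback here
```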
src/core/intent_classifier.py ADDED
@@ -0,0 +1,204 @@
+"""
+Model-based intent classifier for Query Router.
+
+Replaces brittle rule-based heuristics with a trained classifier.
+Backends: tfidf (default), fasttext, distilbert.
+
+Intents: small_to_big (detail), fast (keyword), deep (natural language)
+"""
+
+import logging
+from pathlib import Path
+from typing import Optional
+
+import joblib
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.linear_model import LogisticRegression
+from sklearn.pipeline import Pipeline
+
+logger = logging.getLogger(__name__)
+
+INTENTS = ["small_to_big", "fast", "deep"]
+
+
+class IntentClassifier:
+    """
+    Intent classifier with pluggable backends:
+    - tfidf: TF-IDF + LogisticRegression (~1-2ms)
+    - fasttext: FastText (~1ms, requires fasttext package)
+    - distilbert: Zero-shot DistilBERT (~50-100ms, higher accuracy)
+    """
+
+    def __init__(self, model_path: Optional[Path] = None):
+        self.pipeline: Optional[Pipeline] = None
+        self._fasttext_model = None
+        self._distilbert_pipeline = None
+        self._backend = "tfidf"
+        self.model_path = Path(model_path) if model_path else None
+
+    def load(self, path: Optional[Path] = None) -> bool:
+        """Load trained model from disk."""
+        p = path or self.model_path
+        if not p:
+            return False
+        p = Path(p)
+        base = p.parent if p.suffix in (".pkl", ".bin") else p
+        pkl_path = p if p.suffix == ".pkl" else base / "intent_classifier.pkl"
+        bin_path = p if p.suffix == ".bin" else base / "intent_classifier.bin"
+
+        # Try .pkl first (tfidf or distilbert)
+        if pkl_path.exists():
+            try:
+                data = joblib.load(pkl_path)
+                if isinstance(data, dict):
+                    self.pipeline = data.get("pipeline")
+                    self._backend = data.get("backend", "tfidf")
+                    if self._backend == "distilbert":
+                        self._load_distilbert(data)
+                else:
+                    # Legacy format: a bare sklearn Pipeline
+                    self.pipeline = data
+                self.model_path = pkl_path
+                logger.info("Intent classifier loaded from %s (backend=%s)", pkl_path, self._backend)
+                return True
+            except Exception as e:
+                logger.warning("Failed to load intent classifier: %s", e)
+
+        # Try .bin (FastText)
+        if bin_path.exists():
+            try:
+                import fasttext
+                self._fasttext_model = fasttext.load_model(str(bin_path))
+                self._backend = "fasttext"
+                self.model_path = bin_path
+                logger.info("Intent classifier loaded from %s (FastText)", bin_path)
+                return True
+            except ImportError:
+                logger.warning("FastText not installed; pip install fasttext")
+            except Exception as e:
+                logger.warning("Failed to load FastText: %s", e)
+
+        return False
+
+    def _load_distilbert(self, data: dict) -> None:
+        """Lazy-load DistilBERT pipeline from saved config."""
+        model_name = data.get("distilbert_model", "distilbert-base-uncased")
+        try:
+            from transformers import pipeline
+            self._distilbert_pipeline = pipeline(
+                "zero-shot-classification",
+                model=model_name,
+                device=-1,
+            )
+        except Exception as e:
+            logger.warning("DistilBERT pipeline load failed: %s", e)
+        self.pipeline = None  # Use the DistilBERT pipeline, not sklearn
+
+    def predict(self, query: str) -> str:
+        """Predict intent for a query. Returns one of small_to_big, fast, deep."""
+        q = query.strip()
+        if not q:
+            return "deep"
+
+        if self._fasttext_model is not None:
+            pred = self._fasttext_model.predict(q)
+            return pred[0][0].replace("__label__", "")
+
+        if self._distilbert_pipeline is not None:
+            out = self._distilbert_pipeline(q, INTENTS, multi_label=False)
+            return out["labels"][0]
+
+        if self.pipeline is None:
+            raise RuntimeError("Intent classifier not loaded; call load() first")
+        return str(self.pipeline.predict([q])[0])
+
+    def predict_proba(self, query: str) -> dict[str, float]:
+        """Return intent probabilities for debugging."""
+        q = query.strip()
+        if not q:
+            return {i: 1.0 / len(INTENTS) for i in INTENTS}
+
+        if self._fasttext_model is not None:
+            pred = self._fasttext_model.predict(q, k=len(INTENTS))
+            return dict(zip([label.replace("__label__", "") for label in pred[0]], pred[1]))
+
+        if self._distilbert_pipeline is not None:
+            out = self._distilbert_pipeline(q, INTENTS, multi_label=False)
+            return dict(zip(out["labels"], out["scores"]))
+
+        if self.pipeline is None:
+            raise RuntimeError("Intent classifier not loaded")
+        probs = self.pipeline.predict_proba([q])[0]
+        last_step = self.pipeline.steps[-1][1]
+        classes = getattr(last_step, "classes_", INTENTS)
+        return dict(zip(classes, probs))
+
+
+def train_classifier(
+    queries: list[str],
+    labels: list[str],
+    max_features: int = 5000,
+    C: float = 1.0,
+    backend: str = "tfidf",
+):
+    """
+    Train intent classifier. Returns pipeline (tfidf), model (fasttext), or dict (distilbert).
+    """
+    if backend == "fasttext":
+        return _train_fasttext(queries, labels)
+    if backend == "distilbert":
+        return _train_distilbert(queries, labels)
+    # tfidf default
+    pipeline = Pipeline([
+        ("tfidf", TfidfVectorizer(
+            max_features=max_features,
+            ngram_range=(1, 2),
+            min_df=1,
+            lowercase=True,
+        )),
+        ("clf", LogisticRegression(
+            C=C,
+            max_iter=500,
+            class_weight="balanced",
+            random_state=42,
+        )),
+    ])
+    pipeline.fit(queries, labels)
+    return pipeline
+
+
+def _train_fasttext(queries: list[str], labels: list[str]):
+    """Train FastText classifier. Requires fasttext package."""
+    try:
+        import fasttext
+        import os
+        import tempfile
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
+            for q, label in zip(queries, labels):
+                line = q.replace("\n", " ").strip()
+                f.write(f"__label__{label} {line}\n")
+            path = f.name
+        model = fasttext.train_supervised(path, epoch=25, lr=0.5, wordNgrams=2)
+        os.unlink(path)
+        return model
+    except ImportError:
+        raise RuntimeError("FastText not installed: pip install fasttext")
+
+
+def _train_distilbert(queries: list[str], labels: list[str]) -> dict:
+    """DistilBERT zero-shot: no training step; validates the model and saves config for inference."""
+    try:
+        from transformers import pipeline
+        # Instantiate once so a missing/broken model fails here, not at inference time.
+        pipeline(
+            "zero-shot-classification",
+            model="distilbert-base-uncased",
+            device=-1,
+        )
+        return {
+            "backend": "distilbert",
+            "distilbert_model": "distilbert-base-uncased",
+            "intents": INTENTS,
+        }
+    except Exception as e:
+        raise RuntimeError(f"DistilBERT setup failed: {e}")
src/core/metadata_store.py CHANGED
@@ -142,6 +142,148 @@ class MetadataStore:
             logger.error(f"MetadataStore insert_book failed: {e}")
             return False
 
+    def insert_book_with_fts(self, row: Dict[str, Any]) -> bool:
+        """
+        Insert a new book into both the main table AND the FTS5 index.
+
+        This enables incremental indexing - new books are immediately searchable
+        via keyword search without requiring a full index rebuild.
+
+        Args:
+            row: Book data dict with keys: isbn13, title, description, authors, simple_categories, etc.
+
+        Returns:
+            True if successful, False otherwise
+        """
+        conn = self.connection
+        if not conn:
+            return False
+
+        try:
+            # 1. Insert into main books table
+            if not self.insert_book(row):
+                return False
+
+            # 2. Insert into FTS5 index
+            # FTS5 columns: isbn13, title, description, authors, simple_categories
+            isbn13 = str(row.get("isbn13", ""))
+            title = str(row.get("title", ""))
+            description = str(row.get("description", ""))
+            authors = str(row.get("authors", ""))
+            categories = str(row.get("simple_categories", ""))
+
+            # Check if FTS5 table exists
+            cursor = conn.cursor()
+            cursor.execute(
+                "SELECT name FROM sqlite_master WHERE type='table' AND name='books_fts'"
+            )
+            if not cursor.fetchone():
+                logger.warning("MetadataStore: FTS5 table 'books_fts' not found. Skipping FTS index.")
+                return True  # Main insert succeeded, FTS just not available
+
+            # Insert into FTS5 (use INSERT OR REPLACE to handle updates)
+            cursor.execute(
+                """
+                INSERT OR REPLACE INTO books_fts (isbn13, title, description, authors, simple_categories)
+                VALUES (?, ?, ?, ?, ?)
+                """,
+                (isbn13, title, description, authors, categories)
+            )
+            conn.commit()
+
+            logger.info(f"MetadataStore: Inserted book {isbn13} into FTS5 index")
+            return True
+
+        except sqlite3.OperationalError as e:
+            # Some FTS5 configurations reject OR REPLACE; retry with a plain INSERT
+            if "REPLACE" in str(e):
+                try:
+                    cursor = conn.cursor()
+                    cursor.execute(
+                        """
+                        INSERT INTO books_fts (isbn13, title, description, authors, simple_categories)
+                        VALUES (?, ?, ?, ?, ?)
+                        """,
+                        (isbn13, title, description, authors, categories)
+                    )
+                    conn.commit()
+                    return True
+                except Exception as inner_e:
+                    logger.error(f"MetadataStore FTS5 insert failed: {inner_e}")
+                    return False
+            logger.error(f"MetadataStore FTS5 insert failed: {e}")
+            return False
+        except Exception as e:
+            logger.error(f"MetadataStore insert_book_with_fts failed: {e}")
+            return False
+
+    def book_exists(self, isbn: str) -> bool:
+        """Check if a book with given ISBN exists in the database."""
+        # Strip a trailing ".0" left over from float-parsed ISBNs
+        isbn = str(isbn).strip().replace(".0", "")
+        row = self._query_one(
+            "SELECT 1 FROM books WHERE isbn13 = ? OR isbn10 = ? LIMIT 1",
+            (isbn, isbn)
+        )
+        return row is not None
+
+    def get_newest_book_year(self) -> Optional[int]:
+        """Get the publication year of the newest book in the database."""
+        conn = self.connection
+        if not conn:
+            return None
+        try:
+            cursor = conn.cursor()
+            # Try publishedDate column
+            cursor.execute(
+                "SELECT publishedDate FROM books WHERE publishedDate IS NOT NULL "
+                "ORDER BY publishedDate DESC LIMIT 1"
+            )
+            row = cursor.fetchone()
+            if row and row[0]:
+                # Extract year from date string
+                date_str = str(row[0])
+                if len(date_str) >= 4:
+                    return int(date_str[:4])
+        except Exception as e:
+            logger.debug(f"get_newest_book_year failed: {e}")
+        return None
+
+    def get_book_count(self) -> int:
+        """Get total number of books in the database."""
+        conn = self.connection
+        if not conn:
+            return 0
+        try:
+            cursor = conn.cursor()
+            cursor.execute("SELECT COUNT(*) FROM books")
+            row = cursor.fetchone()
+            return row[0] if row else 0
+        except Exception as e:
+            logger.error(f"get_book_count failed: {e}")
+            return 0
+
+    def get_books_by_year_distribution(self) -> Dict[int, int]:
+        """Get distribution of books by publication year (20 most recent years)."""
+        conn = self.connection
+        if not conn:
+            return {}
+        try:
+            cursor = conn.cursor()
+            cursor.execute(
+                """
+                SELECT SUBSTR(publishedDate, 1, 4) as year, COUNT(*) as count
+                FROM books
+                WHERE publishedDate IS NOT NULL AND LENGTH(publishedDate) >= 4
+                GROUP BY year
+                ORDER BY year DESC
+                LIMIT 20
+                """
+            )
+            return {int(row[0]): row[1] for row in cursor.fetchall() if row[0].isdigit()}
+        except Exception as e:
+            logger.debug(f"get_books_by_year_distribution failed: {e}")
+            return {}
+
     def load_books_processed(self): pass
     def load_train_data(self): pass
 
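A sketch of the incremental-indexing path these helpers enable; field names mirror the books schema used above:

```python
from src.core.metadata_store import metadata_store

book = {
    "isbn13": "9781234567897",  # illustrative ISBN
    "title": "Example Title",
    "authors": "Jane Doe",
    "description": "A newly discovered book.",
    "simple_categories": "Fiction",
}

# Inserts into the books table and, when present, the books_fts FTS5 index,
# so the book is keyword-searchable without a full index rebuild.
if not metadata_store.book_exists(book["isbn13"]):
    metadata_store.insert_book_with_fts(book)

print(metadata_store.get_book_count())
print(metadata_store.get_newest_book_year())
```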
src/core/router.py CHANGED
@@ -1,74 +1,178 @@
 import re
-from typing import Dict, Any, List
+from pathlib import Path
+from typing import Dict, Any, List, Optional
+
 from src.utils import setup_logger
 
 logger = setup_logger(__name__)
 
+
 class QueryRouter:
     """
     Intelligent Router for the RAG Pipeline.
     Classifies user queries to select the optimal retrieval strategy.
 
+    Uses a model-based intent classifier when available; falls back to rule-based
+    heuristics when the classifier is not trained/loaded.
+
     Strategies:
     1. EXACT (ISBN/ID) -> Pure BM25 (High Precision, No Rerank noise).
     2. FAST (Keywords) -> Hybrid (RRF), No Rerank (Low Latency).
     3. DEEP (Complex) -> Hybrid + Rerank (High Latency, High contextual relevance).
-    """
 
-    def __init__(self):
-        # Regex for ISBN-10 and ISBN-13
-        self.isbn_pattern = re.compile(r'^(?:\d{9}[\dX]|\d{13})$')
+    Freshness-Aware Routing:
+    - Detects queries asking for "new", "latest", or specific years (2024, 2025, etc.)
+    - Sets freshness_fallback=True to enable Web Search when local results insufficient
+    """
+
+    # Keywords that indicate the user wants fresh/recent content
+    # Note: Year numbers are detected dynamically in _detect_freshness()
+    FRESHNESS_KEYWORDS = {
+        "new", "newest", "latest", "recent", "modern", "contemporary", "current",
+    }
+
+    # Strong freshness indicators (always trigger fallback)
+    STRONG_FRESHNESS_KEYWORDS = {
+        "newest", "latest",
+    }
+
+    def __init__(self, model_dir: str | Path | None = None):
+        self.isbn_pattern = re.compile(r"^(?:\d{9}[\dX]|\d{13})$")
+        if model_dir is None:
+            from src.config import DATA_DIR
+            model_dir = DATA_DIR / "model"
+        self.model_dir = Path(model_dir)
+        self._classifier = None
+
+    def _get_classifier(self):
+        """Lazy-load the intent classifier when first needed."""
+        if self._classifier is not None:
+            return self._classifier
+        try:
+            from src.core.intent_classifier import IntentClassifier
+            clf = IntentClassifier(self.model_dir / "intent_classifier.pkl")
+            if clf.load():
+                self._classifier = clf
+        except Exception as e:
+            logger.debug("Intent classifier not available: %s", e)
+        return self._classifier
+
+    def _detect_freshness(self, words: list) -> tuple[bool, bool, Optional[int]]:
+        """
+        Detect if a query requires fresh content.
+
+        Returns:
+            (is_temporal, freshness_fallback, target_year)
+            - is_temporal: Should apply temporal boost to local results
+            - freshness_fallback: Should enable Web Search if local results insufficient
+            - target_year: Specific year the user is looking for (if detected)
+        """
+        from datetime import datetime
+        current_year = datetime.now().year
+
+        lower_words = {w.lower() for w in words}
+
+        is_temporal = bool(lower_words & self.FRESHNESS_KEYWORDS)
+        freshness_fallback = bool(lower_words & self.STRONG_FRESHNESS_KEYWORDS)
+
+        # Extract explicit year from query
+        target_year = None
+        for word in words:
+            if word.isdigit() and len(word) == 4:
+                year = int(word)
+                if 2000 <= year <= 2050:
+                    target_year = year
+                    # Recent years (within the last 3 years) trigger freshness
+                    if year >= current_year - 2:
+                        is_temporal = True
+                        freshness_fallback = True
+                    break
+
+        return is_temporal, freshness_fallback, target_year
+
+    def _route_by_rules(
+        self,
+        cleaned_query: str,
+        words: list,
+        is_temporal: bool,
+        freshness_fallback: bool = False,
+        target_year: Optional[int] = None,
+    ) -> Dict[str, Any]:
+        """Fallback: rule-based routing (original logic + freshness)."""
+        detail_keywords = {
+            "twist", "ending", "spoiler", "readers", "felt", "cried", "hated", "loved",
+            "review", "opinion", "think", "unreliable", "narrator", "realize", "find out",
+        }
+
+        base_result = {
+            "temporal": is_temporal,
+            "freshness_fallback": freshness_fallback,
+            "freshness_threshold": 3,  # Trigger web search if < 3 results
+            "target_year": target_year,
+        }
+
+        if any(w.lower() in detail_keywords for w in words):
+            logger.info("Router (rules): Detail Query -> SMALL_TO_BIG")
+            return {**base_result, "strategy": "small_to_big", "alpha": 0.5, "rerank": False, "k_final": 5}
+        if len(words) <= 2:
+            logger.info("Router (rules): Keyword -> FAST (Temporal=%s, Freshness=%s)", is_temporal, freshness_fallback)
+            return {**base_result, "strategy": "fast", "alpha": 0.5, "rerank": False, "k_final": 5}
+        logger.info("Router (rules): Natural Language -> DEEP (Temporal=%s, Freshness=%s)", is_temporal, freshness_fallback)
+        return {**base_result, "strategy": "deep", "alpha": 0.5, "rerank": True, "k_final": 10}
 
     def route(self, query: str) -> Dict[str, Any]:
         """
         Analyze query and return retrieval parameters.
-        Returns dict with: 'strategy', 'hybrid_alpha', 'rerank'
+
+        Returns dict with:
+        - 'strategy': 'exact' | 'fast' | 'deep' | 'small_to_big'
+        - 'alpha': float (hybrid search weight)
+        - 'rerank': bool (use cross-encoder reranking)
+        - 'k_final': int (number of results)
+        - 'temporal': bool (apply temporal boost)
+        - 'freshness_fallback': bool (enable web search if local results insufficient)
+        - 'freshness_threshold': int (min local results before triggering web search)
+        - 'target_year': int | None (specific year the user requested)
         """
         cleaned_query = query.strip()
         words = cleaned_query.split()
 
-        # 1. Check for ISBN (Exact Match)
-        # Remove hyphens/spaces for check
+        # 1. ISBN: keep regex (deterministic, correct)
         normalized = cleaned_query.replace("-", "").replace(" ", "")
         if self.isbn_pattern.match(normalized):
-            logger.info(f"Router: Detected ISBN -> EXACT Strategy ({normalized})")
-            return {"strategy": "exact", "alpha": 1.0, "rerank": False, "k_final": 5}
-
-        # 2. Check for Temporal Keywords (Freshness Bias)
-        temporal_keywords = {"new", "newest", "latest", "recent", "modern", "contemporary", "2020", "2021", "2022", "2023", "2024", "2025"}
-        is_temporal = any(word.lower() in temporal_keywords for word in words)
-
-        # 3. Check for Detail-Oriented Queries (Triggers Small-to-Big)
-        # These are queries asking about specific plot points, reactions, or hidden details
-        detail_keywords = {"twist", "ending", "spoiler", "readers", "felt", "cried", "hated", "loved",
-                           "review", "opinion", "think", "unreliable", "narrator", "realize", "find out"}
-        is_detail = any(word.lower() in detail_keywords for word in words)
-
-        if is_detail:
-            logger.info(f"Router: Detected Detail Query -> SMALL_TO_BIG Strategy")
-            return {
-                "strategy": "small_to_big",
-                "rerank": False,  # Small-to-Big already does precision matching
-                "k_final": 5,
-                "temporal": is_temporal
-            }
-
-        # 4. Check for Simple Keyword Search (Short queries)
-        if len(words) <= 2:
-            logger.info(f"Router: Detected Keyword -> FAST Strategy (Temporal={is_temporal})")
-            return {
-                "strategy": "fast",
-                "rerank": False,  # Skip expensive rerank
-                "k_final": 5,
-                "temporal": is_temporal
-            }
-
-        # 5. Default to Deep Search
-        logger.info(f"Router: Detected Natural Language -> DEEP Strategy (Temporal={is_temporal})")
-        return {
-            "strategy": "deep",
-            "rerank": True,
-            "k_final": 10,
-            "temporal": is_temporal
-        }
+            logger.info("Router: ISBN -> EXACT (%s)", normalized)
+            return {
+                "strategy": "exact",
+                "alpha": 1.0,
+                "rerank": False,
+                "k_final": 5,
+                "temporal": False,
+                "freshness_fallback": False,
+                "freshness_threshold": 1,
+                "target_year": None,
+            }
+
+        # 2. Freshness detection (temporal boost + web fallback)
+        is_temporal, freshness_fallback, target_year = self._detect_freshness(words)
+
+        # 3. Model-based vs rule-based intent
+        clf = self._get_classifier()
+        if clf is not None:
+            try:
+                intent = clf.predict(cleaned_query)
+                logger.info("Router (model): %s (Freshness=%s)", intent.upper(), freshness_fallback)
+                return {
+                    "strategy": intent,
+                    "alpha": 0.5,
+                    "rerank": intent == "deep",
+                    "k_final": 10 if intent == "deep" else 5,
+                    "temporal": is_temporal,
+                    "freshness_fallback": freshness_fallback,
+                    "freshness_threshold": 3,
+                    "target_year": target_year,
+                }
+            except Exception as e:
+                logger.warning("Intent classifier failed, falling back to rules: %s", e)
+
+        return self._route_by_rules(cleaned_query, words, is_temporal, freshness_fallback, target_year)
 
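Example routing decisions under the new contract, with output keys as documented in route():

```python
from src.core.router import QueryRouter

router = QueryRouter()

print(router.route("9780439708180"))
# -> {'strategy': 'exact', 'alpha': 1.0, 'rerank': False, 'k_final': 5, ...}

print(router.route("latest 2025 sci-fi"))
# -> freshness_fallback=True, target_year=2025; the strategy comes from the
#    classifier when data/model/intent_classifier.pkl loads, rules otherwise.
```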
src/core/web_search.py ADDED
@@ -0,0 +1,323 @@
+"""
+Web Search Fallback Module for Data Freshness.
+
+Extends the existing Google Books integration in cover_fetcher.py to support
+full book metadata retrieval and keyword-based search for new books.
+
+This module provides:
+- search_google_books(): Search books by keyword query
+- fetch_book_by_isbn(): Get complete metadata for a single book
+- is_fresh_enough(): Evaluate if local search results meet freshness requirements
+
+API: Google Books API (free tier, no auth required for basic queries)
+Rate Limit: ~1000 requests/day (unofficial), so cache conservatively
+"""
+import re
+from datetime import datetime
+from functools import lru_cache
+from typing import Optional
+
+import requests
+
+from src.utils import setup_logger
+
+logger = setup_logger(__name__)
+
+# Google Books API endpoint
+GOOGLE_BOOKS_API = "https://www.googleapis.com/books/v1/volumes"
+
+# Request timeout (seconds)
+REQUEST_TIMEOUT = 5
+
+
+def _parse_isbn_from_identifiers(identifiers: list[dict]) -> tuple[str, str]:
+    """
+    Extract ISBN-13 and ISBN-10 from Google Books industryIdentifiers.
+
+    Returns:
+        (isbn13, isbn10) - empty strings if not found
+    """
+    isbn13, isbn10 = "", ""
+    for ident in identifiers:
+        id_type = ident.get("type", "")
+        id_value = ident.get("identifier", "")
+        if id_type == "ISBN_13":
+            isbn13 = id_value
+        elif id_type == "ISBN_10":
+            isbn10 = id_value
+    return isbn13, isbn10
+
+
+def _parse_volume_info(volume_info: dict) -> Optional[dict]:
+    """
+    Parse Google Books volumeInfo into our standard book format.
+
+    Returns:
+        dict with keys: isbn13, isbn10, title, authors, description,
+        publishedDate, thumbnail, categories
+        None if ISBN is missing (we can't index without ISBN)
+    """
+    identifiers = volume_info.get("industryIdentifiers", [])
+    isbn13, isbn10 = _parse_isbn_from_identifiers(identifiers)
+
+    # Skip books without ISBN - can't be indexed reliably
+    if not isbn13 and not isbn10:
+        return None
+
+    # Use isbn13 as primary, fall back to isbn10
+    primary_isbn = isbn13 if isbn13 else isbn10
+
+    # Extract image links (prefer larger sizes)
+    image_links = volume_info.get("imageLinks", {})
+    thumbnail = (
+        image_links.get("extraLarge") or
+        image_links.get("large") or
+        image_links.get("medium") or
+        image_links.get("small") or
+        image_links.get("thumbnail") or
+        ""
+    )
+    # Ensure HTTPS
+    if thumbnail.startswith("http://"):
+        thumbnail = thumbnail.replace("http://", "https://")
+
+    # Categories: Google returns a list, we keep the first as a single string
+    categories = volume_info.get("categories", [])
+    category_str = categories[0] if categories else "General"
+
+    return {
+        "isbn13": primary_isbn,  # Primary key
+        "isbn10": isbn10 or (isbn13[:10] if isbn13 else ""),  # Rough fallback, not a true conversion
+        "title": volume_info.get("title", "Unknown Title"),
+        "authors": ", ".join(volume_info.get("authors", ["Unknown"])),
+        "description": volume_info.get("description", ""),
+        "publishedDate": volume_info.get("publishedDate", ""),
+        "thumbnail": thumbnail,
+        "simple_categories": category_str,
+        "average_rating": volume_info.get("averageRating", 0.0),
+        "source": "google_books",  # Track data source
+    }
+
+
+def search_google_books(query: str, max_results: int = 10) -> list[dict]:
+    """
+    Search Google Books by keyword query.
+
+    Args:
+        query: Search query (e.g., "latest sci-fi 2024", "new fantasy novels")
+        max_results: Maximum number of results (1-40, Google's limit)
+
+    Returns:
+        List of book dicts in standard format, ordered by relevance
+    """
+    if not query or not query.strip():
+        return []
+
+    max_results = min(max_results, 40)  # Google's API limit
+
+    try:
+        params = {
+            "q": query,
+            "maxResults": max_results,
+            "printType": "books",
+            "orderBy": "relevance",
+        }
+
+        response = requests.get(
+            GOOGLE_BOOKS_API,
+            params=params,
+            timeout=REQUEST_TIMEOUT
+        )
+
+        if response.status_code != 200:
+            logger.warning(f"Google Books API returned {response.status_code}")
+            return []
+
+        data = response.json()
+        total_items = data.get("totalItems", 0)
+
+        if total_items == 0:
+            logger.info(f"No results for query: {query}")
+            return []
+
+        items = data.get("items", [])
+        results = []
+
+        for item in items:
+            volume_info = item.get("volumeInfo", {})
+            parsed = _parse_volume_info(volume_info)
+            if parsed:
+                results.append(parsed)
+
+        logger.info(f"Google Books search '{query}': {len(results)} valid results")
+        return results
+
+    except requests.Timeout:
+        logger.warning(f"Google Books API timeout for query: {query}")
+        return []
+    except requests.RequestException as e:
+        logger.error(f"Google Books API request failed: {e}")
+        return []
+    except Exception as e:
+        logger.error(f"Unexpected error in search_google_books: {e}")
+        return []
+
+
+@lru_cache(maxsize=500)
+def fetch_book_by_isbn(isbn: str) -> Optional[dict]:
+    """
+    Fetch complete book metadata by ISBN from Google Books.
+
+    Args:
+        isbn: ISBN-10 or ISBN-13
+
+    Returns:
+        Book dict in standard format, or None if not found
+    """
+    if not isbn or not isbn.strip():
+        return None
+
+    isbn = isbn.strip().replace("-", "")
+
+    try:
+        params = {
+            "q": f"isbn:{isbn}",
+            "maxResults": 1,
+        }
+
+        response = requests.get(
+            GOOGLE_BOOKS_API,
+            params=params,
+            timeout=REQUEST_TIMEOUT
+        )
+
+        if response.status_code != 200:
+            return None
+
+        data = response.json()
+        if data.get("totalItems", 0) == 0:
+            return None
+
+        items = data.get("items", [])
+        if not items:
+            return None
+
+        volume_info = items[0].get("volumeInfo", {})
+        return _parse_volume_info(volume_info)
+
+    except Exception as e:
+        logger.debug(f"fetch_book_by_isbn({isbn}) failed: {e}")
+        return None
+
+
+def search_new_books_by_category(
+    category: str,
+    year: Optional[int] = None,
+    max_results: int = 10
+) -> list[dict]:
+    """
+    Search for recently published books in a specific category.
+
+    Args:
+        category: Book category (e.g., "fiction", "science fiction", "mystery")
+        year: Filter by publication year (default: current year)
+        max_results: Maximum number of results
+
+    Returns:
+        List of book dicts, filtered to the specified year or newer
+    """
+    if year is None:
+        year = datetime.now().year
+
+    # Build query with subject filter
+    query = f"subject:{category}"
+
+    # Get more results than needed, filter by year locally
+    raw_results = search_google_books(query, max_results=max_results * 2)
+
+    filtered = []
+    for book in raw_results:
+        pub_date = book.get("publishedDate", "")
+        if pub_date:
+            # Extract year from various date formats (YYYY, YYYY-MM, YYYY-MM-DD)
+            try:
+                pub_year = int(pub_date[:4])
+                if pub_year >= year:
+                    filtered.append(book)
+            except (ValueError, IndexError):
+                continue
+
+    return filtered[:max_results]
+
+
+def is_fresh_enough(
+    results: list[dict],
+    threshold: int = 3,
+    min_year: Optional[int] = None
+) -> bool:
+    """
+    Evaluate if local search results meet freshness requirements.
+
+    Args:
+        results: List of book dicts from local search
+        threshold: Minimum number of results required
+        min_year: If specified, count only books published >= this year
+
+    Returns:
+        True if results are sufficient, False if web fallback should be triggered
+    """
+    if len(results) < threshold:
+        return False
+
+    if min_year is None:
+        return True
+
+    # Count books meeting the year requirement
+    fresh_count = 0
+    for book in results:
+        pub_date = book.get("publishedDate", "") or book.get("published_date", "")
+        if pub_date:
+            try:
+                pub_year = int(str(pub_date)[:4])
+                if pub_year >= min_year:
+                    fresh_count += 1
+            except (ValueError, IndexError):
+                continue
+
+    return fresh_count >= threshold
+
+
+def extract_year_from_query(query: str) -> Optional[int]:
+    """
+    Extract year requirement from user query.
+
+    Examples:
+        "books from 2024" -> 2024
+        "latest 2025 novels" -> 2025
+        "new sci-fi" -> current_year - 1
+        "classic mystery" -> None
+
+    Returns:
+        Year as int, or None if no year requirement detected
+    """
+    # Explicit year patterns
+    year_patterns = [
+        r"\b(202[0-9])\b",  # 2020-2029
+        r"\b(201[0-9])\b",  # 2010-2019
+    ]
+
+    for pattern in year_patterns:
+        match = re.search(pattern, query)
+        if match:
+            return int(match.group(1))
+
+    # Keywords implying "recent" = current year - 1
+    freshness_keywords = {
+        "new", "newest", "latest", "recent", "modern", "contemporary", "current"
+    }
+
+    words = set(query.lower().split())
+    if words & freshness_keywords:
+        return datetime.now().year - 1
+
+    return None
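
A usage sketch of the fallback helpers. Note these are live calls against the public Google Books endpoint, so results vary:

```python
from src.core.web_search import (
    search_google_books, fetch_book_by_isbn,
    extract_year_from_query, is_fresh_enough,
)

year = extract_year_from_query("new sci-fi from 2024")   # -> 2024
local_results = []  # pretend local retrieval came back empty

if not is_fresh_enough(local_results, threshold=3, min_year=year):
    for book in search_google_books("new sci-fi 2024", max_results=5):
        print(book["isbn13"], book["title"], book["publishedDate"])

print(fetch_book_by_isbn("9781234567897"))  # any ISBN-13 string; cached via lru_cache
```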
src/ranking/din.py ADDED
@@ -0,0 +1,212 @@
+"""
+DIN (Deep Interest Network) for CTR/Ranking.
+
+Uses attention over the user behavior sequence w.r.t. the target item to capture
+user interest. Reuses SASRec item embeddings as initialization when available.
+
+Reference: Zhou et al., "Deep Interest Network for Click-Through Rate Prediction" (KDD 2018)
+"""
+
+import logging
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+logger = logging.getLogger(__name__)
+
+
+class DIN(nn.Module):
+    """
+    Deep Interest Network: attention over user history w.r.t. target item.
+
+    Input:
+        - user_hist: [B, max_len] int64, padded behavior sequence (item_ids, 0=pad)
+        - target_item: [B] int64, candidate item ids
+        - aux_features: [B, num_aux] float32, optional scalar features
+
+    Output: [B] logits for click/positive probability
+    """
+
+    def __init__(
+        self,
+        num_items: int,
+        embed_dim: int = 64,
+        max_hist_len: int = 50,
+        mlp_dims: tuple = (128, 64, 32),
+        dropout: float = 0.1,
+        num_aux: int = 0,
+        pretrained_item_emb: Optional[np.ndarray] = None,
+    ):
+        super().__init__()
+        self.num_items = num_items
+        self.embed_dim = embed_dim
+        self.max_hist_len = max_hist_len
+        self.num_aux = num_aux
+
+        # Item embedding (1-indexed, 0=pad)
+        self.item_emb = nn.Embedding(num_items + 1, embed_dim, padding_idx=0)
+
+        if pretrained_item_emb is not None:
+            self._init_from_pretrained(pretrained_item_emb)
+
+        # Attention: local activation unit (DIN paper)
+        self.attn_fc = nn.Sequential(
+            nn.Linear(embed_dim * 4, 36),
+            nn.ReLU(),
+            nn.Linear(36, 1),
+        )
+
+        # MLP: [user_interest; target_emb; aux?] -> score
+        mlp_in = embed_dim * 2 + num_aux
+        layers = []
+        for d in mlp_dims:
+            layers.append(nn.Linear(mlp_in, d))
+            layers.append(nn.ReLU())
+            layers.append(nn.Dropout(dropout))
+            mlp_in = d
+        layers.append(nn.Linear(mlp_in, 1))
+        self.mlp = nn.Sequential(*layers)
+
+    def _init_from_pretrained(self, emb: np.ndarray) -> None:
+        """Initialize item_emb from a SASRec checkpoint."""
+        if emb.shape[1] == self.embed_dim:
+            # Copy only as many rows as both tables share, so a larger
+            # pretrained table cannot overflow item_emb.
+            n = min(emb.shape[0], self.num_items + 1)
+            with torch.no_grad():
+                self.item_emb.weight.data[:n].copy_(torch.from_numpy(emb[:n]))
+            logger.info("DIN: Initialized item_emb from pretrained (%d x %d)", n, emb.shape[1])
+        else:
+            logger.warning("DIN: Pretrained shape %s mismatch, skipping init", emb.shape)
+
+    def forward(
+        self,
+        user_hist: torch.Tensor,
+        target_item: torch.Tensor,
+        aux_features: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        user_hist: [B, L]
+        target_item: [B]
+        aux_features: [B, num_aux] or None
+        """
+        # [B, L, E]
+        hist_embs = self.item_emb(user_hist)
+        # [B, E]
+        target_emb = self.item_emb(target_item)
+
+        # Attention: local activation unit
+        # Expand target to [B, L, E] to score it against every history item
+        target_expand = target_emb.unsqueeze(1).expand(-1, user_hist.size(1), -1)
+        attn_input = torch.cat([
+            hist_embs,
+            target_expand,
+            hist_embs * target_expand,
+            hist_embs - target_expand,
+        ], dim=-1)  # [B, L, 4E]
+        attn_scores = self.attn_fc(attn_input).squeeze(-1)  # [B, L]
+
+        # Mask padding (0 = pad)
+        mask = (user_hist != 0).float()
+        attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
+        attn_weights = F.softmax(attn_scores, dim=1)  # [B, L]
+        # With an empty history the masked softmax is uniform over padding;
+        # zero those weights out and renormalize defensively.
+        attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
+        attn_weights = attn_weights * mask
+        attn_weights = attn_weights / (attn_weights.sum(dim=1, keepdim=True) + 1e-9)
+
+        # Weighted sum: user interest vector [B, E]
+        user_interest = (hist_embs * attn_weights.unsqueeze(-1)).sum(dim=1)
+
+        # MLP input
+        mlp_in = torch.cat([user_interest, target_emb], dim=1)
+        if aux_features is not None and self.num_aux > 0 and aux_features.size(-1) == self.num_aux:
+            mlp_in = torch.cat([mlp_in, aux_features], dim=1)
+
+        logits = self.mlp(mlp_in).squeeze(-1)
+        return logits
+
+
+class DINRanker:
+    """
+    Wrapper for the DIN model: load, predict, compatible with RecommendationService.
+    """
+
+    def __init__(
+        self,
+        data_dir: str = "data/rec",
+        model_dir: str = "data/model",
+    ):
+        self.data_dir = Path(data_dir)
+        self.model_dir = Path(model_dir) / "ranking"
+        self.model: Optional[DIN] = None
+        self.item_map: dict = {}
+        self.id_to_item: dict = {}
+        self.user_sequences: dict = {}
+        self.max_hist_len = 50
+        self.aux_feature_names: list = []
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        if torch.backends.mps.is_available():
+            self.device = torch.device("mps")
+
+    def load(self) -> bool:
+        """Load trained DIN and aux data."""
+        import pickle
+
+        model_path = self.model_dir / "din_ranker.pt"
+        if not model_path.exists():
+            return False
+
+        try:
+            ckpt = torch.load(model_path, map_location=self.device, weights_only=False)
+            self.model = ckpt["model"]
+            self.model.to(self.device)
+            self.model.eval()
+            self.item_map = ckpt.get("item_map", {})
+            self.id_to_item = {v: k for k, v in self.item_map.items()}
+            self.max_hist_len = ckpt.get("max_hist_len", 50)
+            self.aux_feature_names = ckpt.get("aux_feature_names", [])
+
+            with open(self.data_dir / "user_sequences.pkl", "rb") as f:
+                # user_sequences: user_id -> list of item_ids
+                self.user_sequences = pickle.load(f)
+
+            logger.info("DIN ranker loaded from %s", model_path)
+            return True
+        except Exception as e:
+            logger.error("Failed to load DIN ranker: %s", e)
+            return False
+
+    def predict(
+        self,
+        user_id: str,
+        candidate_items: list[str],
+        aux_features: Optional[np.ndarray] = None,
+    ) -> np.ndarray:
+        """Predict scores for (user_id, candidate_items). Returns [len(candidate_items)]."""
+        if self.model is None:
+            self.load()
+        if self.model is None:
+            return np.zeros(len(candidate_items))
+
+        hist = self.user_sequences.get(user_id, [])
+        if hist and isinstance(hist[0], str):
+            hist = [self.item_map.get(h, 0) for h in hist]
+        hist = hist[-self.max_hist_len:]
+        padded = np.zeros(self.max_hist_len, dtype=np.int64)
+        padded[: len(hist)] = hist
+
+        target_ids = np.array([self.item_map.get(str(it), 0) for it in candidate_items], dtype=np.int64)
+
+        hist_t = torch.LongTensor(padded).unsqueeze(0).expand(len(candidate_items), -1).to(self.device)
+        target_t = torch.LongTensor(target_ids).to(self.device)
+
+        aux_t = None
+        if aux_features is not None and aux_features.size > 0:
+            aux_t = torch.from_numpy(aux_features.astype(np.float32)).to(self.device)
+
+        with torch.no_grad():
+            logits = self.model(hist_t, target_t, aux_t)
+        return logits.cpu().numpy()
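
A quick shape check for the model above, using random tensors and no trained weights:

```python
import torch
from src.ranking.din import DIN

model = DIN(num_items=1000, embed_dim=64, max_hist_len=50, num_aux=2)

hist = torch.randint(0, 1001, (8, 50))   # [B, L], 0 = padding
target = torch.randint(1, 1001, (8,))    # [B]
aux = torch.randn(8, 2)                  # [B, num_aux]

logits = model(hist, target, aux)
assert logits.shape == (8,)              # one score per (user, candidate) pair
```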
src/recommender.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import List, Dict, Any
2
  from src.vector_db import VectorDB
3
  from src.config import TOP_K_INITIAL, TOP_K_FINAL, DATA_DIR
4
  from src.cache import CacheManager
@@ -136,13 +136,108 @@ class BookRecommender:
136
  "emotions": emotions,
137
  "review_highlights": highlights,
138
  "persona_summary": "",
139
- "average_rating": float(meta.get("average_rating", 0.0))
 
140
  })
141
 
142
  if len(results) >= TOP_K_FINAL:
143
  break
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  def get_categories(self) -> List[str]:
148
  """Get unique book categories from SQLite."""
@@ -152,20 +247,47 @@ class BookRecommender:
152
  """Get available emotional tones."""
153
  return ["All", "Happy", "Sad", "Fear", "Anger", "Surprise"]
154
 
155
- def add_new_book(self, isbn: str, title: str, author: str, description: str, category: str = "General", thumbnail: str = None) -> Any:
 
 
 
 
 
 
 
 
 
156
  """
157
- Add a new book to the system: CSV, Memory, and Vector DB.
158
- Returns the new book dictionary if successful, None otherwise.
 
 
 
 
 
 
 
 
 
 
 
159
  """
160
  try:
161
  import pandas as pd
162
 
 
 
 
 
 
 
 
163
  # 1. Update Persistent Storage (CSV)
164
  csv_path = DATA_DIR / "books_processed.csv"
165
 
166
  # Define new row with all expected columns
167
  new_row = {
168
- "isbn13": isbn,
169
  "title": title,
170
  "authors": author,
171
  "description": description,
@@ -174,13 +296,10 @@ class BookRecommender:
174
  "average_rating": 0.0,
175
  "joy": 0.0, "sadness": 0.0, "fear": 0.0, "anger": 0.0, "surprise": 0.0,
176
  "tags": "", "review_highlights": "",
177
- "isbn10": str(isbn)[:10] # Approximation
 
 
178
  }
179
-
180
- isbn_s = str(isbn)
181
- if metadata_store.get_book_metadata(isbn_s):
182
- logger.warning(f"Book {isbn} already exists. Skipping add.")
183
- return None
184
 
185
  # Append to CSV
186
  if csv_path.exists():
@@ -191,19 +310,20 @@ class BookRecommender:
191
  # Filter/Order new_row to match CSV structure
192
  ordered_row = {}
193
  for col in csv_columns:
194
- ordered_row[col] = new_row.get(col, "") # Default to empty string if missing
195
 
196
  # Append to CSV
197
  pd.DataFrame([ordered_row]).to_csv(csv_path, mode='a', header=False, index=False)
198
  else:
199
- pd.DataFrame([new_row]).to_csv(csv_path, index=False)
200
 
201
  new_row["large_thumbnail"] = new_row["thumbnail"]
 
202
 
203
- # 3. Insert into SQLite (zero-RAM mode)
204
- metadata_store.insert_book(new_row)
205
 
206
- # 4. Update Vector DB (Chroma + BM25)
207
  self.vector_db.add_book(new_row)
208
 
209
  logger.info(f"Successfully added book {isbn}: {title}")
 
1
+ from typing import List, Dict, Any, Optional
2
  from src.vector_db import VectorDB
3
  from src.config import TOP_K_INITIAL, TOP_K_FINAL, DATA_DIR
4
  from src.cache import CacheManager
 
136
  "emotions": emotions,
137
  "review_highlights": highlights,
138
  "persona_summary": "",
139
+ "average_rating": float(meta.get("average_rating", 0.0)),
140
+ "source": "local", # Track data source
141
  })
142
 
143
  if len(results) >= TOP_K_FINAL:
144
  break
145
+
146
+ # 3. Web Search Fallback (Freshness-Aware)
147
+        # Triggered when: freshness_fallback=True AND local results < threshold
+        if decision.get("freshness_fallback", False):
+            threshold = decision.get("freshness_threshold", 3)
+            if len(results) < threshold:
+                web_results = self._fetch_from_web(query, TOP_K_FINAL - len(results), category)
+                results.extend(web_results)
+                logger.info(f"Web fallback added {len(web_results)} books")
+
+        # Cache the results
+        if results:
+            self.cache.set(cache_key, results)
 
         return results
+
+    def _fetch_from_web(
+        self,
+        query: str,
+        max_results: int,
+        category: str = "All"
+    ) -> List[Dict[str, Any]]:
+        """
+        Fetch books from the Google Books API when local results are insufficient.
+        Auto-persists discovered books to the local database for future queries.
+
+        Args:
+            query: User's search query
+            max_results: Maximum number of results to fetch
+            category: Category filter (applied client-side to web results, not to the API call)
+
+        Returns:
+            List of formatted book dicts ready for response
+        """
+        try:
+            from src.core.web_search import search_google_books
+        except ImportError:
+            logger.warning("Web search module not available")
+            return []
+
+        results = []
+
+        try:
+            # Over-fetch to compensate for the duplicate and category filtering below
+            web_books = search_google_books(query, max_results=max_results * 2)
+
+            for book in web_books:
+                isbn = book.get("isbn13", "")
+                if not isbn:
+                    continue
+
+                # Skip if already in local database
+                if metadata_store.book_exists(isbn):
+                    continue
+
+                # Category filter (if specified)
+                if category and category != "All":
+                    book_cat = book.get("simple_categories", "")
+                    if category.lower() not in book_cat.lower():
+                        continue
+
+                # Auto-persist to local database
+                added = self.add_new_book(
+                    isbn=isbn,
+                    title=book.get("title", ""),
+                    author=book.get("authors", "Unknown"),
+                    description=book.get("description", ""),
+                    category=book.get("simple_categories", "General"),
+                    thumbnail=book.get("thumbnail"),
+                    published_date=book.get("publishedDate", ""),
+                )
+
+                if added:
+                    results.append({
+                        "isbn": isbn,
+                        "title": book.get("title", ""),
+                        "authors": book.get("authors", "Unknown"),
+                        "description": book.get("description", ""),
+                        "thumbnail": book.get("thumbnail", ""),
+                        "caption": f"{book.get('title', '')} by {book.get('authors', 'Unknown')}",
+                        "tags": [],
+                        "emotions": {"joy": 0.0, "sadness": 0.0, "fear": 0.0, "anger": 0.0, "surprise": 0.0},
+                        "review_highlights": [],
+                        "persona_summary": "",
+                        "average_rating": float(book.get("average_rating", 0.0)),
+                        "source": "google_books",  # Track data source
+                    })
+
+                if len(results) >= max_results:
+                    break
+
+            logger.info(f"Web fallback: Found and persisted {len(results)} new books")
+            return results
+
+        except Exception as e:
+            logger.error(f"Web fallback failed: {e}")
+            return []
 
     def get_categories(self) -> List[str]:
         """Get unique book categories from SQLite."""
…
         """Get available emotional tones."""
         return ["All", "Happy", "Sad", "Fear", "Anger", "Surprise"]
 
+    def add_new_book(
+        self,
+        isbn: str,
+        title: str,
+        author: str,
+        description: str,
+        category: str = "General",
+        thumbnail: Optional[str] = None,
+        published_date: Optional[str] = None,
+    ) -> Optional[Dict[str, Any]]:
         """
+        Add a new book to the system: CSV, SQLite (with FTS5), and ChromaDB.
+
+        Args:
+            isbn: ISBN-13 or ISBN-10
+            title: Book title
+            author: Author name(s)
+            description: Book description
+            category: Book category
+            thumbnail: Cover image URL
+            published_date: Publication date (YYYY, YYYY-MM, or YYYY-MM-DD)
+
+        Returns:
+            New book dictionary if successful, None otherwise
         """
         try:
             import pandas as pd
 
+            isbn_s = str(isbn).strip()
+
+            # Check if already exists
+            if metadata_store.book_exists(isbn_s):
+                logger.debug(f"Book {isbn} already exists. Skipping add.")
+                return None
+
             # 1. Update Persistent Storage (CSV)
             csv_path = DATA_DIR / "books_processed.csv"
 
             # Define new row with all expected columns
             new_row = {
+                "isbn13": isbn_s,
                 "title": title,
                 "authors": author,
                 "description": description,
…
                 "average_rating": 0.0,
                 "joy": 0.0, "sadness": 0.0, "fear": 0.0, "anger": 0.0, "surprise": 0.0,
                 "tags": "", "review_highlights": "",
+                # Approximation only: slicing an ISBN-13 does not yield a checksum-valid ISBN-10
+                "isbn10": isbn_s[:10] if len(isbn_s) >= 10 else isbn_s,
+                "publishedDate": published_date or "",
+                "source": "google_books",  # Track data source
             }
 
             # Append to CSV
             if csv_path.exists():
…
                 # Filter/Order new_row to match CSV structure
                 ordered_row = {}
                 for col in csv_columns:
+                    ordered_row[col] = new_row.get(col, "")
 
                 # Append to CSV
                 pd.DataFrame([ordered_row]).to_csv(csv_path, mode='a', header=False, index=False)
             else:
+                pd.DataFrame([new_row]).to_csv(csv_path, index=False)
 
             new_row["large_thumbnail"] = new_row["thumbnail"]
+            new_row["image"] = new_row["thumbnail"]
 
+            # 2. Insert into SQLite with FTS5 (incremental indexing)
+            metadata_store.insert_book_with_fts(new_row)
 
+            # 3. Update Vector DB (ChromaDB)
             self.vector_db.add_book(new_row)
 
             logger.info(f"Successfully added book {isbn}: {title}")
src/services/recommend_service.py CHANGED
@@ -7,10 +7,12 @@ from pathlib import Path
 from src.recall.fusion import RecallFusion
 from src.ranking.features import FeatureEngineer
 from src.ranking.explainer import RankingExplainer
+from src.ranking.din import DINRanker
 from src.utils import setup_logger
 
 logger = setup_logger(__name__)
 
+
 class RecommendationService:
     def __init__(self, data_dir='data/rec', model_dir='data/model'):
         self.data_dir = Path(data_dir)
@@ -21,6 +23,8 @@ class RecommendationService:
 
         self.ranker = None
         self.ranker_loaded = False
+        self.din_ranker = DINRanker(str(data_dir), str(model_dir))
+        self.din_ranker_loaded = False
         self.xgb_ranker = None
         self.meta_model = None
         self.use_stacking = False
@@ -34,7 +38,14 @@ class RecommendationService:
         self.fusion.load_models()
         self.fe.load_base_data()
 
-        # Load Ranker (LightGBM)
+        # Prefer DIN ranker when available (deep model)
+        din_path = self.model_dir / 'ranking/din_ranker.pt'
+        if din_path.exists():
+            if self.din_ranker.load():
+                self.din_ranker_loaded = True
+                logger.info("DIN ranker loaded — using deep model for ranking")
+
+        # Load LGBM ranker (fallback when DIN not available)
         ranker_path = self.model_dir / 'ranking/lgbm_ranker.txt'
         if ranker_path.exists():
             self.ranker = lgb.Booster(model_file=str(ranker_path))
@@ -119,29 +130,34 @@ class RecommendationService:
         candidate_items = [item for item, score in candidates]
 
         # 2. Ranking
-        if self.ranker_loaded:
-            # Filter candidates first
-            valid_candidates = [item for item in candidate_items if item not in fav_isbns]
-
-            if not valid_candidates:
-                return []
+        valid_candidates = [item for item in candidate_items if item not in fav_isbns]
+        if not valid_candidates:
+            return []
 
-            # Batch Feature Generation (Optimized)
+        if self.din_ranker_loaded:
+            # DIN: deep model; optional aux features from FeatureEngineer
+            aux_arr = None
+            if self.din_ranker.aux_feature_names:
+                X_df = self.fe.generate_features_batch(user_id, valid_candidates)
+                for col in self.din_ranker.aux_feature_names:
+                    if col not in X_df.columns:
+                        X_df[col] = 0
+                aux_arr = X_df[self.din_ranker.aux_feature_names].values.astype(np.float32)
+            scores = self.din_ranker.predict(user_id, valid_candidates, aux_arr)
+            explanations_list = [[] for _ in valid_candidates]
+            final_scores = list(zip(valid_candidates, scores, explanations_list))
+            final_scores.sort(key=lambda x: x[1], reverse=True)
+        elif self.ranker_loaded:
+            # LGBM / stacking path
             X_df = self.fe.generate_features_batch(user_id, valid_candidates)
-
-            # Align features to match model
             model_features = self.ranker.feature_name()
             for col in model_features:
                 if col not in X_df.columns:
                     X_df[col] = 0
             X_df = X_df[model_features]
 
-            # Predict
             if self.use_stacking and self.xgb_ranker is not None and self.meta_model is not None:
-                # Stacking: Level-1 predictions -> Level-2 meta-learner
                 lgb_scores = self.ranker.predict(X_df)
-
-                # Check if XGB Ranker is a raw Booster or Sklearn Estimator
                 if isinstance(self.xgb_ranker, xgb.Booster):
                     dtest = xgb.DMatrix(X_df)
                     xgb_scores = self.xgb_ranker.predict(dtest)
@@ -150,24 +166,19 @@ class RecommendationService:
                 meta_features = np.column_stack([lgb_scores, xgb_scores])
                 scores = self.meta_model.predict_proba(meta_features)[:, 1]
             else:
-                # Fallback: LightGBM only (backward compatible)
                 scores = self.ranker.predict(X_df)
 
-            # Compute SHAP explanations (V2.7)
             explanations_list = []
             if self.explainer is not None:
                 try:
                     explanations_list = self.explainer.explain(X_df, top_k=3)
                 except Exception as e:
-                    logger.warning(f"SHAP explanation failed: {e}")
                     explanations_list = [[] for _ in valid_candidates]
             else:
                 explanations_list = [[] for _ in valid_candidates]
 
-            # Combine with explanations
             final_scores = list(zip(valid_candidates, scores, explanations_list))
             final_scores.sort(key=lambda x: x[1], reverse=True)
-
         else:
             # Fallback to recall scores, but filter
             final_scores = []
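The diff shows how `DINRanker` is wired in, but not what the deep model computes. As background: DIN (Deep Interest Network) scores each candidate by attending over the user's interaction history, using the candidate itself as the attention query. The sketch below is a generic DIN-style attention unit in PyTorch, not the repo's `src/ranking/din.py`; all class names, dimensions, and the hidden size are illustrative assumptions.

```python
import torch
import torch.nn as nn

class DINAttention(nn.Module):
    """DIN-style attention: weight history items by their relevance to the candidate."""

    def __init__(self, emb_dim: int, hidden: int = 36):
        super().__init__()
        # Scoring MLP over [hist, cand, hist - cand, hist * cand]
        self.mlp = nn.Sequential(
            nn.Linear(4 * emb_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, hist: torch.Tensor, cand: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # hist: (B, T, D) history item embeddings; cand: (B, D) candidate embedding
        # mask: (B, T), 1 for real history items, 0 for padding
        cand_exp = cand.unsqueeze(1).expand_as(hist)                  # (B, T, D)
        x = torch.cat([hist, cand_exp, hist - cand_exp, hist * cand_exp], dim=-1)
        logits = self.mlp(x).squeeze(-1)                              # (B, T)
        logits = logits.masked_fill(mask == 0, torch.finfo(logits.dtype).min)
        weights = torch.softmax(logits, dim=-1)                       # (B, T)
        return torch.bmm(weights.unsqueeze(1), hist).squeeze(1)       # (B, D)

# A full DIN ranker would embed the candidate and the user's recent ISBNs,
# pool the history with DINAttention, concatenate [pooled_vec, cand_emb, aux features]
# (the aux_arr passed in the diff above), and feed an MLP ending in a sigmoid score.
```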
src/vector_db.py CHANGED
@@ -321,7 +321,14 @@ class VectorDB:
 
     def add_book(self, book_data: dict):
         """
-        Dynamically add a new book to the vector database and update indices.
+        Dynamically add a new book to the ChromaDB vector index.
+
+        Note: FTS5 incremental updates are handled separately via
+        metadata_store.insert_book_with_fts() called from BookRecommender.add_new_book().
+        This method only handles the dense vector index.
+
+        Args:
+            book_data: Dict with isbn13, title, authors, description, etc.
         """
         from langchain_core.documents import Document
 
@@ -330,7 +337,7 @@ class VectorDB:
         author = book_data.get("authors", "")
         description = book_data.get("description", "")
 
-        # 1. Add to Chroma
+        # Add to ChromaDB (dense vector index)
         content = f"Title: {title}\nAuthor: {author}\nDescription: {description}\nISBN: {isbn}"
         doc = Document(
             page_content=content,
@@ -346,7 +353,4 @@ class VectorDB:
         if self.db:
             self.db.add_documents([doc])
             logger.info(f"Added book {isbn} to ChromaDB")
-
-        if hasattr(self, 'fts_enabled') and self.fts_enabled:
-            logger.info("Note: FTS5 database updates are not implemented in add_book yet.")
 
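One design note on the ChromaDB half of this split: `add_documents` as written inserts the same ISBN again if `add_book` is ever called outside the `book_exists` guard. A possible hardening, not part of this commit, is to key each document by its ISBN so repeated adds become idempotent. Whether passing `ids` upserts or merely warns depends on the installed langchain/Chroma version, so treat this as a sketch to verify, not established behavior.

```python
# Hypothetical hardening of the Chroma write path: key each document by ISBN.
# Recent langchain-chroma releases route add_documents(ids=...) through upsert,
# which would make re-adds of the same book idempotent. Verify on your version.
if self.db:
    self.db.add_documents([doc], ids=[isbn])
    logger.info(f"Added/updated book {isbn} in ChromaDB")
```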