chore: update requirements and documentation for intent classifier and RAG evaluation
- data/rag_golden.README.md +32 -0
- docs/TECHNICAL_REPORT.md +8 -2
- docs/interview_guide.md +46 -17
- requirements.txt +3 -0
- scripts/data/build_sequences.py +41 -82
- scripts/data/fetch_new_books.py +322 -0
- scripts/data/validate_data.py +19 -1
- scripts/model/evaluate_rag.py +169 -0
- scripts/model/train_din_ranker.py +264 -0
- scripts/model/train_intent_router.py +156 -0
- scripts/model/train_ranker.py +5 -8
- scripts/run_pipeline.py +28 -1
- src/core/freshness_monitor.py +231 -0
- src/core/intent_classifier.py +204 -0
- src/core/metadata_store.py +142 -0
- src/core/router.py +151 -47
- src/core/web_search.py +323 -0
- src/ranking/din.py +212 -0
- src/recommender.py +137 -17
- src/services/recommend_service.py +30 -19
- src/vector_db.py +9 -5
data/rag_golden.README.md
ADDED
@@ -0,0 +1,32 @@
# RAG Golden Test Set

Human-annotated Query-Book pairs for quantitative RAG evaluation.

## Format

CSV with columns: `query`, `isbn`, `relevance`, `notes`

- **query**: User search string (e.g., "Harry Potter", "0060959479", "books about AI")
- **isbn**: Expected relevant book ISBN (from your catalog)
- **relevance**: 1 = relevant (filter rows with relevance=1)
- **notes**: Optional annotation note

Multiple rows per query = multiple relevant books (Recall@K counts all of them).

## Usage

```bash
# Copy the example file and extend it with ISBNs from your catalog
cp data/rag_golden.example.csv data/rag_golden.csv

# Run the evaluation
python scripts/model/evaluate_rag.py --golden data/rag_golden.csv
```

## Metrics

- **Accuracy@K**: Fraction of queries with at least one relevant book in the top-K
- **Recall@K**: Fraction of relevant books (across all queries) found in the top-K
- **MRR@K**: Mean reciprocal rank of the first relevant hit

Target: 500+ pairs for production-quality evaluation.
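For reference, the metric definitions above fit in a few lines of Python. This is a minimal sketch, not the project's `evaluate_rag.py`; the `retrieve` callable and the shape of the `golden` mapping are illustrative assumptions:

```python
# Minimal sketch of the three metrics defined above (illustrative names).
# golden: {query: set of relevant ISBNs}; retrieve(query) -> ranked ISBN list.
def evaluate(golden, retrieve, k=10):
    acc = recall = mrr = 0.0
    for query, relevant in golden.items():
        top = retrieve(query)[:k]
        # 1-based ranks of the relevant items that appear in the top-K
        ranks = [i for i, isbn in enumerate(top, start=1) if isbn in relevant]
        acc += 1.0 if ranks else 0.0             # Accuracy@K: any hit in top-K
        recall += len(ranks) / len(relevant)     # Recall@K: share of relevant found
        mrr += 1.0 / ranks[0] if ranks else 0.0  # MRR@K: first hit's reciprocal rank
    n = len(golden)
    return {"accuracy@k": acc / n, "recall@k": recall / n, "mrr@k": mrr / n}
```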
docs/TECHNICAL_REPORT.md
CHANGED
@@ -217,6 +217,8 @@ Architecture: Self-Attentive Sequential Recommendation with Transformer blocks
 - Training: 30 epochs, 64-dim embeddings, BCE loss with negative sampling
 - Dual use: (1) ranking feature via `sasrec_score`, (2) independent recall channel via embedding dot-product
 
+**Time-split (no leakage)**: SASRec is trained on `train.csv` only. `user_seq_emb` and `sas_item_emb` are computed from train-only sequences. When Ranking uses `sasrec_score` for val samples, the user's history contains only train interactions, never val/test. `build_sequences.py` and SASRec/YoutubeDNN all use train-only data.
+
 ### 4.3 LGBMRanker (LambdaRank) + Model Stacking
 
 Replaced the XGBoost binary classifier with a LightGBM LambdaRank model that directly optimizes NDCG. In v2.6.0, a Stacking ensemble (LGBMRanker + XGBClassifier → LogisticRegression meta-learner) further improves ranking robustness.

@@ -226,6 +228,8 @@ Replaced XGBoost binary classifier with LightGBM LambdaRank that directly optimi
 - 20K users sampled from the 168K validation set for training speed
 - 4× negative ratio per positive sample
 
+**Feature consistency**: Recall models (SASRec, ItemCF, etc.) are trained on train.csv. Ranking labels come from val.csv. Features like `sasrec_score` use train-only embeddings. Pipeline order: `split_rec_data` → `build_sequences` (train-only) → recall models (train) → ranker (val).
+
 **17 features** in 5 groups:
 - User statistics: u_cnt, u_mean, u_std
 - Item statistics: i_cnt, i_mean, i_std

@@ -264,10 +268,12 @@ Feature importance (v2.6.0 LGBMRanker, representative subset):
 |--------|------------------------|-------------|
 | ISBN Recall | 0% | 100% |
 | Keyword Precision | Low | High (BM25 boost) |
-| Detail Query Recall | 0% |
+| Detail Query Recall | 0% | Golden Test Set (Accuracy@K, Recall@K, MRR@K) |
 | Avg Latency | 100ms | 300-800ms |
 | Chat Context Limit | ~10 turns | Extended via compression (no formal limit) |
 
+**Golden Test Set**: Human-annotated Query-Book pairs (`data/rag_golden.csv`) replace curated examples. Run `python scripts/model/evaluate_rag.py` for Accuracy@K, Recall@K, and MRR@K. Extend with ~500+ pairs for production.
+
 ### 5.2 Latency Benchmarks
 
 | Operation | P50 Latency (Warm) | P95 Latency (Warm) |

@@ -371,7 +377,7 @@ src/
 
 - **Single-dataset evaluation**: All RecSys metrics are on Amazon Books 200K; no cross-domain or external validation.
 - **Rule-based router**: Intent classification uses heuristics (e.g., `len(words) <= 2` for keyword); may not generalize to other domains.
-- **RAG evaluation**:
+- **RAG evaluation**: Use the Golden Test Set (`data/rag_golden.csv`) for Accuracy@K, Recall@K, MRR@K. Extend to 500+ human-annotated Query-Book pairs for production.
 - **Protocol sensitivity**: RecSys metrics can vary with evaluation protocol (e.g., ISBN-only vs title-relaxed matching); see [Experiment Archive](experiments/experiment_archive.md) for discussion.
 
 ---
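As a side note, the Stacking ensemble from section 4.3 can be sketched as follows. This is a minimal illustration on synthetic data, with classifier base models standing in for the report's LGBMRanker (LambdaRank); all hyperparameters and data here are placeholder assumptions, not the project's training script:

```python
# Stacking sketch: out-of-fold scores from two base models feed a
# LogisticRegression meta-learner (synthetic data, illustrative settings).
import numpy as np
from lightgbm import LGBMClassifier
from xgboost import XGBClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict

rng = np.random.default_rng(42)
X = rng.random((1000, 17))  # 17 ranking features, mirroring section 4.3
y = (X[:, 0] + 0.3 * rng.random(1000) > 0.65).astype(int)

# Out-of-fold probabilities keep the base-model scores honest: the
# meta-learner never sees predictions made on a base model's own training fold
p_lgb = cross_val_predict(LGBMClassifier(n_estimators=100), X, y,
                          cv=5, method="predict_proba")[:, 1]
p_xgb = cross_val_predict(XGBClassifier(n_estimators=100), X, y,
                          cv=5, method="predict_proba")[:, 1]

meta = LogisticRegression().fit(np.column_stack([p_lgb, p_xgb]), y)
print("meta-learner weights:", meta.coef_)
```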
docs/interview_guide.md
CHANGED
@@ -5,51 +5,80 @@
 ## 🌟 Core Highlights (Why this project?)
 
 ### 1. Architecture Depth
+
+* **Agentic RAG**: Not just plain vector retrieval; the system adds **Dynamic Routing** and picks the best retrieval strategy (BM25, Hybrid, Small-to-Big) from the user's intent (e.g., exact ISBN search vs. fuzzy semantic search), showing fine-grained control over the RAG pipeline.
+* **Stacking Ensemble**: The ranking stage does not stop at a single model; it implements a **LightGBM + XGBoost + Logistic Regression** stacking architecture, reflecting an understanding of model bias and variance and an engineering drive for the best possible recommendation quality.
+* **Vector Database**: Semantic search built on ChromaDB, in line with the current LLM + Vector Store trend.
 
 ### 2. Engineering Excellence
+
+* **Performance Optimization**:
+  * **Problem**: The system stuttered under concurrent load, and inference latency was high.
+  * **Fixes** (see the sketch after this section):
+    1. **Async/Await pitfall**: CPU-bound work (Pandas operations) ran inside FastAPI `async` routes and blocked the event loop. Adding `await` does not help; non-IO work must leave `async` or move to a thread pool. Switching to synchronous `def` handlers lets FastAPI use its thread pool automatically.
+    2. **Vectorization refactor**: Feature generation used native Python `for` loops. Rewriting it as vectorized Numpy/Pandas operations (exploiting SIMD) sped up inference roughly 10×.
+    3. **Singleton pattern**: A `MetadataStore` singleton avoids reloading the CSV on every request, significantly cutting memory and I/O overhead.
+* **Explainability**: Integrated **SHAP (SHapley Additive exPlanations)**. The recommender is no longer a "black box": it can explain in real time why a book was recommended (e.g., because you like author X, or because you mostly read this category), a key marker separating advanced projects from beginner ones.
 
 ### 3. Completeness
+
+* **Full Stack**: Frontend (React) + backend (FastAPI) + data flow (ETL) + model training (train scripts) + deployment (Docker).
+* **DevOps**: Ships with a Dockerfile and complete build scripts, ready for production deployment.
 
 ---
 
 ## 🗣️ Interview Talking Points & Q&A Strategy
 
 ### Q1: What was the biggest difficulty you hit in this project, and how did you solve it?
+
 **Suggested answer**:
+
 > "The part that left the deepest impression was **system performance optimization**.
 > The first version had very high inference latency under concurrent load and could even block the whole service.
 > I solved it on two levels:
+>
+> 1. **Architecture level**: Profiling showed that FastAPI `async` endpoints contained heavy Pandas processing. Because Python `async` is single-threaded and cooperative, CPU-bound work stalls the event loop outright. I refactored those endpoints into non-async handlers that run on FastAPI's thread pool, which removed the blocking.
+> 2. **Code level**: Feature engineering was originally written with Python loops. I rewrote it as **Numpy vectorized** operations, replacing per-element Python-interpreter overhead with C-level matrix math, which made feature generation more than 10× faster."
 
 ### Q2: Why a Stacking ensemble? Isn't LightGBM alone enough?
+
 **Suggested answer**:
+
 > "A single model always has limitations.
 > LightGBM is strong on categorical features and gradient boosting, while XGBoost handles regularization very well.
 > With Stacking, I use a simple Logistic Regression as the meta-learner over the outputs of these two strong models.
 > This exploits the strengths of different models (reducing bias and variance) and improves **robustness**. In my offline experiments, Stacking clearly beat single LightGBM on NDCG@10."
 
 ### Q3: What is special about your RAG system?
+
 **Suggested answer**:
+
 > "My RAG system is not a plain 'Retrieve then Generate'. I designed an **Agentic Router**.
 > It first classifies the user's intent: an ISBN lookup goes straight to exact matching; a fuzzy description goes through the semantic index; a complex query triggers Rerank.
 > This dynamic strategy resolves the classic RAG pain point of 'precise but incomplete vs. complete but imprecise'."
 
+**Possible follow-up questions an interviewer might ask:**
+
+**Q1. On the intuition behind the Swing algorithm:**
+
+> "I see you used Swing recall. Can you explain intuitively why Swing resists noise better than classic UserCF? In the formula `1 / (alpha + |I_u ∩ I_v|)`, what kind of user pair does the denominator penalize?"
+> *(What it probes: real understanding of the algorithm vs. just calling a library. The key is that Swing penalizes user pairs that are "already very similar", i.e., tight cliques, which promotes serendipity.)*
+
+**Q2. On RAG latency optimization:**
+
+> "Your report says Hybrid Search + Rerank takes about 800ms. If we deployed this system in Douyin's search box with a P99 latency budget of 200ms, which stages would you cut, or how would you optimize through engineering?"
+> *(What it probes: engineering thinking. Possible answers include parallel requests, vector-index quantization/HNSW, Rerank model distillation, caching hot queries, async loading of details, etc.)*
+
+**Q3. SASRec serving details:**
+
+> "In `src/model/sasrec.py` you use a Transformer. At inference time, if we refresh recommendations on every click, SASRec is expensive. How would you cache the user's embedding state so you don't recompute the whole sequence from scratch each time?"
+> *(What it probes: understanding of online inference optimization for deep models. Key ideas: KV cache or incremental computation.)*
 
 ---
 
 ## 📈 Key Metrics
+
+* **Hit Rate@10**: 0.4545 (v2.6.0, n=2000, Leave-Last-Out)
+* **MRR@5**: 0.2893 (Title-relaxed matching)
+* **Latency**: P99 < 50ms (Personalized Recs)
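The event-loop pitfall from the Q1 answer can be shown with a minimal FastAPI sketch; the route names and DataFrame are hypothetical, not the project's service code:

```python
# Sketch of the async pitfall: CPU-bound Pandas work inside `async def`
# blocks the single-threaded event loop, because `await` only yields at IO
# points. A plain `def` route is dispatched to FastAPI's thread pool instead.
from fastapi import FastAPI
import pandas as pd

app = FastAPI()
df = pd.DataFrame({"user_id": range(1_000_000), "score": 0.5})

@app.get("/bad/{user_id}")
async def bad(user_id: int):
    # CPU-bound filtering runs on the event loop and stalls all other requests
    return df[df["user_id"] == user_id].to_dict("records")

@app.get("/good/{user_id}")
def good(user_id: int):
    # Sync handler: FastAPI runs this on a worker thread automatically
    return df[df["user_id"] == user_id].to_dict("records")
```

And the Swing denominator from the first follow-up question, as a minimal sketch assuming plain dict inputs (not the project's recall implementation):

```python
# Swing similarity sketch: for an item pair (i, j), sum over user pairs (u, v)
# that both touched i and j. The 1 / (alpha + |I_u ∩ I_v|) term down-weights
# user pairs that already share many items (tight cliques), which is exactly
# the noise-resistance property the question asks about.
from itertools import combinations

def swing_similarity(item_users: dict[str, set[str]],
                     user_items: dict[str, set[str]],
                     i: str, j: str, alpha: float = 1.0) -> float:
    common_users = item_users[i] & item_users[j]
    score = 0.0
    for u, v in combinations(sorted(common_users), 2):
        overlap = len(user_items[u] & user_items[v])  # |I_u ∩ I_v|
        score += 1.0 / (alpha + overlap)
    return score
```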
requirements.txt
CHANGED
@@ -39,6 +39,9 @@ scikit-learn
 scipy
 requests
 
+# Intent classifier backends (optional)
+# fasttext  # Uncomment for the FastText backend: pip install fasttext
+
 # LLM Agent & Fine-tuning
 faiss-cpu
 diffusers
scripts/data/build_sequences.py
CHANGED
@@ -4,24 +4,23 @@ Build User Sequences for Sequential Models (SASRec, YoutubeDNN)
 
 Converts user interaction history into padded sequences for training.
 
+TIME-SPLIT (strict): Uses train.csv ONLY for sequences and item_map.
+This prevents leakage when Ranking uses SASRec embeddings as features:
+- Val/test samples must not appear in user history when computing sasrec_score.
+- Recall models (SASRec, YoutubeDNN) overwrite these with their own train-only output.
+
 Usage:
     python scripts/data/build_sequences.py
 
 Input:
-    - data/rec/train.csv
+    - data/rec/train.csv (val.csv and test.csv exist but are NOT used for sequences)
 
 Output:
-    - data/rec/user_sequences.pkl (Dict[user_id, List[item_id]])
-    - data/rec/item_map.pkl (Dict[isbn, item_id])
-
-Notes:
-    - Item IDs are 1-indexed (0 is reserved for padding)
-    - Sequences are truncated to max_len (default: 50)
-    - Test item is excluded from sequences (used for evaluation)
+    - data/rec/user_sequences.pkl (Dict[user_id, List[item_id]], train-only)
+    - data/rec/item_map.pkl (Dict[isbn, item_id], train-only)
 """
 
 import pandas as pd
-import numpy as np
 import pickle
 import logging
 from pathlib import Path

@@ -30,83 +29,43 @@ from tqdm import tqdm
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+
+def build_sequences(data_dir="data/rec", max_len=50):
     """
+    Build user sequences from train.csv only (strict time-split).
+    Val/test are excluded to avoid leakage in ranking features (sasrec_score).
     """
-    logger.info("Building user sequences...")
+    logger.info("Building user sequences (train-only, time-split)...")
+    train_df = pd.read_csv(f"{data_dir}/train.csv")
+
+    # 1. Item map from train only (matches SASRec/YoutubeDNN)
+    items = train_df["isbn"].unique()
+    item_map = {isbn: i + 1 for i, isbn in enumerate(items)}
+    logger.info("  Items (train): %d", len(item_map))
+
+    # 2. User history from train only (no val/test)
+    user_history = {}
+    if "timestamp" in train_df.columns:
+        train_df = train_df.sort_values(["user_id", "timestamp"])
+    for _, row in tqdm(train_df.iterrows(), total=len(train_df), desc="  Processing"):
+        u = str(row["user_id"])
+        item = item_map.get(row["isbn"])
+        if item is None:
+            continue
+        if u not in user_history:
+            user_history[u] = []
+        user_history[u].append(item)
+
+    final_seqs = {u: hist[-max_len:] for u, hist in user_history.items()}
+    logger.info("  Users: %d", len(final_seqs))
+
+    data_dir = Path(data_dir)
+    data_dir.mkdir(parents=True, exist_ok=True)
+    with open(data_dir / "item_map.pkl", "wb") as f:
         pickle.dump(item_map, f)
-
-    # 2. Group by User and Sort by Time
-    # Note: Our data split script ALREADY sorted by time.
-    # But let's be safe. We need original timestamps if possible.
-    # 'train.csv' doesn't have timestamp column? let me check split_rec_data.
-    # Ah, split_rec_data removed it. But rows are ordered.
-    # Actually, we can just group by user_id and assume rows are chronological
-    # IF we process train -> val -> test order.
-
-    # Let's reconstruct full history per user
-    logger.info("Grouping user history...")
-
-    # Optimization: processing via dictionary is faster than groupby on large df
-    user_history = {}  # user_id -> list of item_ids
-
-    def process_df(df):
-        for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing"):
-            u = row['user_id']
-            item = item_map[row['isbn']]
-            if u not in user_history:
-                user_history[u] = []
-            user_history[u].append(item)
-
-    # Process in chronological order: Train -> Val -> Test
-    # Wait, split_rec_data puts last item in test, 2nd last in val.
-    # So correct order is: Train rows + Val row + Test row.
-
-    # BUT, train.csv has multiple rows per user. They are already sorted by time in split logic.
-    process_df(train_df)
-    process_df(val_df)
-
-    # We leave Test item out of the input sequence!
-    # Test item is the target for evaluation.
-    # For training SASRec, we use (seq[:-1]) -> predict (seq[1:]).
-
-    # 3. Create Dataset
-    # Output:
-    #   train_seqs: Dict[user_id, list_of_ints]
-
-    # Pad/Truncate
-    final_seqs = {}
-
-    for u, history in user_history.items():
-        # Truncate to max_len
-        seq = history[-max_len:]
-        final_seqs[u] = seq
-
-    logger.info(f"Processed {len(final_seqs)} users.")
-
-    # Save sequences
-    with open(f'{data_dir}/user_sequences.pkl', 'wb') as f:
+    with open(data_dir / "user_sequences.pkl", "wb") as f:
         pickle.dump(final_seqs, f)
-
-    logger.info("Sequence data saved.")
+    logger.info("Sequence data saved (train-only).")
 
 if __name__ == "__main__":
     build_sequences()
scripts/data/fetch_new_books.py
ADDED
@@ -0,0 +1,322 @@
#!/usr/bin/env python3
"""
Incremental Book Update Script.

Fetches recently published books from the Google Books API and adds them to the local database.
Can be run manually or scheduled via cron for periodic updates.

Usage:
    python scripts/data/fetch_new_books.py [--categories CATEGORIES] [--year YEAR] [--max MAX]

Examples:
    # Fetch new books from the current year (default behavior)
    python scripts/data/fetch_new_books.py --categories "fiction" --max 50

    # Fetch new books across multiple categories
    python scripts/data/fetch_new_books.py --categories "fiction,mystery,science fiction"

    # Explicitly specify a year filter
    python scripts/data/fetch_new_books.py --year 2026 --categories "thriller"

    # Dry run (show what would be added without actually adding)
    python scripts/data/fetch_new_books.py --dry-run --categories "thriller"
"""
import argparse
import sys
import time
from pathlib import Path
from datetime import datetime
from typing import Optional

# Add project root to path
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

from src.utils import setup_logger
from src.core.web_search import search_new_books_by_category, search_google_books
from src.core.metadata_store import metadata_store
from src.recommender import BookRecommender

logger = setup_logger(__name__)

# Default categories to search
DEFAULT_CATEGORIES = [
    "fiction",
    "mystery",
    "thriller",
    "science fiction",
    "fantasy",
    "romance",
    "biography",
    "history",
    "self-help",
    "business",
]


def fetch_trending_books(
    categories: list[str],
    year: Optional[int] = None,
    max_per_category: int = 20,
    dry_run: bool = False,
) -> dict:
    """
    Fetch recently published books from Google Books for the given categories.

    Args:
        categories: List of book categories to search
        year: Filter by publication year (default: current year)
        max_per_category: Max books to fetch per category
        dry_run: If True, don't actually add books to the database

    Returns:
        Dict with stats: {added: int, skipped: int, errors: int, books: list}
    """
    if year is None:
        year = datetime.now().year

    stats = {
        "added": 0,
        "skipped": 0,
        "errors": 0,
        "books": [],
    }

    recommender = None
    if not dry_run:
        recommender = BookRecommender()

    for category in categories:
        logger.info(f"Fetching books for category: {category} (year >= {year})")

        try:
            books = search_new_books_by_category(
                category=category,
                year=year,
                max_results=max_per_category
            )

            logger.info(f"  Found {len(books)} books in '{category}'")

            for book in books:
                isbn = book.get("isbn13", "")
                if not isbn:
                    continue

                # Check if already exists
                if metadata_store.book_exists(isbn):
                    stats["skipped"] += 1
                    continue

                if dry_run:
                    logger.info(f"  [DRY RUN] Would add: {book.get('title', 'Unknown')} ({isbn})")
                    stats["books"].append(book)
                    stats["added"] += 1
                else:
                    result = recommender.add_new_book(
                        isbn=isbn,
                        title=book.get("title", ""),
                        author=book.get("authors", "Unknown"),
                        description=book.get("description", ""),
                        category=book.get("simple_categories", category),
                        thumbnail=book.get("thumbnail"),
                        published_date=book.get("publishedDate", ""),
                    )

                    if result:
                        stats["added"] += 1
                        stats["books"].append(book)
                        logger.info(f"  Added: {book.get('title', 'Unknown')} ({isbn})")
                    else:
                        stats["errors"] += 1

                # Rate limiting: avoid hitting API limits
                time.sleep(0.1)

            # Pause between categories
            time.sleep(0.5)

        except Exception as e:
            logger.error(f"Error fetching category '{category}': {e}")
            stats["errors"] += 1

    return stats


def fetch_by_query(
    queries: list[str],
    max_per_query: int = 20,
    dry_run: bool = False,
) -> dict:
    """
    Fetch books by specific search queries (e.g., "AI books 2024", "new thriller novels").

    Args:
        queries: List of search queries
        max_per_query: Max books per query
        dry_run: If True, don't actually add books

    Returns:
        Stats dict
    """
    stats = {
        "added": 0,
        "skipped": 0,
        "errors": 0,
        "books": [],
    }

    recommender = None
    if not dry_run:
        recommender = BookRecommender()

    for query in queries:
        logger.info(f"Searching: {query}")

        try:
            books = search_google_books(query, max_results=max_per_query)
            logger.info(f"  Found {len(books)} results")

            for book in books:
                isbn = book.get("isbn13", "")
                if not isbn:
                    continue

                if metadata_store.book_exists(isbn):
                    stats["skipped"] += 1
                    continue

                if dry_run:
                    logger.info(f"  [DRY RUN] Would add: {book.get('title', 'Unknown')}")
                    stats["books"].append(book)
                    stats["added"] += 1
                else:
                    result = recommender.add_new_book(
                        isbn=isbn,
                        title=book.get("title", ""),
                        author=book.get("authors", "Unknown"),
                        description=book.get("description", ""),
                        category=book.get("simple_categories", "General"),
                        thumbnail=book.get("thumbnail"),
                        published_date=book.get("publishedDate", ""),
                    )

                    if result:
                        stats["added"] += 1
                        stats["books"].append(book)
                    else:
                        stats["errors"] += 1

                time.sleep(0.1)

            time.sleep(0.5)

        except Exception as e:
            logger.error(f"Error with query '{query}': {e}")
            stats["errors"] += 1

    return stats


def print_stats(stats: dict, dry_run: bool = False):
    """Print summary statistics."""
    prefix = "[DRY RUN] " if dry_run else ""
    print(f"\n{prefix}=== Fetch Complete ===")
    print(f"  Books added:   {stats['added']}")
    print(f"  Books skipped: {stats['skipped']} (already in database)")
    print(f"  Errors:        {stats['errors']}")

    if stats["books"] and dry_run:
        print("\nBooks that would be added:")
        for book in stats["books"][:10]:
            print(f"  - {book.get('title', 'Unknown')} by {book.get('authors', 'Unknown')}")
        if len(stats["books"]) > 10:
            print(f"  ... and {len(stats['books']) - 10} more")


def main():
    parser = argparse.ArgumentParser(
        description="Fetch new books from the Google Books API",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__
    )

    parser.add_argument(
        "--categories",
        type=str,
        default=None,
        help="Comma-separated list of categories (default: all common categories)"
    )
    parser.add_argument(
        "--queries",
        type=str,
        default=None,
        help="Comma-separated list of custom search queries"
    )
    parser.add_argument(
        "--year",
        type=int,
        default=None,
        help="Filter by publication year (default: current year)"
    )
    parser.add_argument(
        "--max",
        type=int,
        default=20,
        help="Max books per category/query (default: 20)"
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be added without actually adding"
    )
    parser.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="Enable verbose logging"
    )

    args = parser.parse_args()

    # Parse categories
    if args.categories:
        categories = [c.strip() for c in args.categories.split(",")]
    else:
        categories = DEFAULT_CATEGORIES

    # Parse queries
    queries = None
    if args.queries:
        queries = [q.strip() for q in args.queries.split(",")]

    print("Book Fetch Configuration:")
    print(f"  Categories:   {categories if not queries else 'N/A (using queries)'}")
    print(f"  Queries:      {queries or 'N/A (using categories)'}")
    print(f"  Year filter:  >= {args.year or datetime.now().year}")
    print(f"  Max per item: {args.max}")
    print(f"  Dry run:      {args.dry_run}")
    print()

    # Fetch books
    if queries:
        stats = fetch_by_query(
            queries=queries,
            max_per_query=args.max,
            dry_run=args.dry_run,
        )
    else:
        stats = fetch_trending_books(
            categories=categories,
            year=args.year,
            max_per_category=args.max,
            dry_run=args.dry_run,
        )

    print_stats(stats, args.dry_run)

    return 0 if stats["errors"] == 0 else 1


if __name__ == "__main__":
    sys.exit(main())
scripts/data/validate_data.py
CHANGED
@@ -144,9 +144,27 @@ def validate_rec():
         print(f"  User sequences: {len(seqs):,}")
         avg_len = np.mean([len(s) for s in seqs.values()])
         print(f"  Avg sequence length: {avg_len:.1f}")
+
+        # Time-split: no val items in sequences (prevents sasrec_score leakage)
+        if ITEM_MAP.exists():
+            with open(ITEM_MAP, "rb") as f:
+                item_map = pickle.load(f)
+            leaked = 0
+            for _, row in val.iterrows():
+                uid, val_isbn = str(row["user_id"]), str(row["isbn"])
+                if uid not in seqs:
+                    continue
+                val_iid = item_map.get(val_isbn)
+                if val_iid is None:
+                    continue  # val item not in map (train-only) -> no leak possible
+                if val_iid in seqs[uid]:
+                    leaked += 1
+            check(leaked == 0, f"Time-split violation: {leaked} users have val items in sequence")
+            print("  ✅ Time-split OK (no val in sequences)")
     else:
         print("  ⚠️ User sequences not found (run build_sequences.py)")
+
     print("  ✅ Rec data validation passed")
     return True
scripts/model/evaluate_rag.py
ADDED
@@ -0,0 +1,169 @@
#!/usr/bin/env python3
"""
Evaluate RAG retrieval on a Golden Test Set.

Replaces "curated examples" with quantitative metrics: Accuracy@K, Recall@K, MRR@K.
Uses human-annotated Query-Book pairs for data-driven evaluation.

Usage:
    python scripts/model/evaluate_rag.py
    python scripts/model/evaluate_rag.py --golden data/rag_golden.csv --top_k 10

Golden set format (CSV): query, isbn, relevance
    - query: user search string
    - isbn: expected relevant book (relevance=1 means relevant)
    - Multiple rows per query = multiple relevant books
"""

import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))

import pandas as pd
import logging
from collections import defaultdict

from src.recommender import BookRecommender

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


def load_golden(path: Path) -> dict[str, set[str]]:
    """Load golden set: {query -> set of relevant isbns}."""
    df = pd.read_csv(path, comment="#")
    if "relevance" in df.columns:
        df = df[df["relevance"] == 1]  # Only relevant pairs
    golden = defaultdict(set)
    for _, row in df.iterrows():
        q = str(row["query"]).strip()
        isbn = str(row["isbn"]).strip().replace(".0", "")
        if q and isbn:
            golden[q].add(isbn)
    return dict(golden)


def evaluate_rag(
    golden_path: Path | str = "data/rag_golden.csv",
    top_k: int = 10,
    use_title_match: bool = True,
) -> dict:
    """
    Run RAG retrieval on the golden set and compute metrics.

    Returns: dict with accuracy_at_k, recall_at_k, mrr_at_k, n_queries
    """
    golden_path = Path(golden_path)
    if not golden_path.exists():
        # Fall back to the example file
        alt = Path("data/rag_golden.example.csv")
        if alt.exists():
            logger.warning("Golden set not found at %s, using %s", golden_path, alt)
            golden_path = alt
        else:
            raise FileNotFoundError(
                f"Golden set not found. Create {golden_path} with columns: query,isbn,relevance. "
                "See data/rag_golden.example.csv for the format."
            )

    golden = load_golden(golden_path)
    if not golden:
        raise ValueError("Golden set is empty")

    logger.info("Evaluating RAG on %d queries from %s", len(golden), golden_path)

    recommender = BookRecommender()
    isbn_to_title = {}
    if use_title_match:
        try:
            bp = Path("data/books_processed.csv")
            if not bp.exists():
                bp = Path(__file__).resolve().parent.parent.parent / "data" / "books_processed.csv"
            books = pd.read_csv(bp, usecols=["isbn13", "title"])
            books["isbn13"] = books["isbn13"].astype(str).str.replace(r"\.0$", "", regex=True)
            isbn_to_title = books.set_index("isbn13")["title"].to_dict()
        except Exception as e:
            logger.warning("Could not load title map: %s", e)
            use_title_match = False

    hits_acc = 0
    recall_sum = 0.0
    mrr_sum = 0.0

    for query, relevant_isbns in golden.items():
        try:
            recs = recommender.get_recommendations(query, top_k=top_k * 2)
            rec_isbns = [r.get("isbn") or r.get("isbn13") for r in recs if r]
            rec_isbns = [str(x).replace(".0", "") for x in rec_isbns if pd.notna(x)]
            rec_top = rec_isbns[:top_k]

            # Match: exact ISBN or (optionally) identical title
            def _match(target: str, cand_list: list) -> int:
                for i, c in enumerate(cand_list):
                    if str(c).strip() == str(target).strip():
                        return i
                    if use_title_match:
                        t_title = isbn_to_title.get(str(target), "").lower().strip()
                        c_title = isbn_to_title.get(str(c), "").lower().strip()
                        if t_title and c_title and t_title == c_title:
                            return i
                return -1

            # Accuracy@K: at least one relevant item in the top-K
            found_any = False
            first_rank = top_k + 1
            count_in_top = 0

            for rel in relevant_isbns:
                rk = _match(rel, rec_top)
                if rk >= 0:
                    found_any = True
                    count_in_top += 1
                    first_rank = min(first_rank, rk + 1)

            if found_any:
                hits_acc += 1
            recall_sum += count_in_top / len(relevant_isbns) if relevant_isbns else 0
            if first_rank <= top_k:
                mrr_sum += 1.0 / first_rank

        except Exception as e:
            logger.warning("Query %r failed: %s", query[:50], e)

    n = len(golden)
    return {
        "accuracy_at_k": hits_acc / n,
        "recall_at_k": recall_sum / n,
        "mrr_at_k": mrr_sum / n,
        "n_queries": n,
        "top_k": top_k,
    }


def main():
    import argparse
    parser = argparse.ArgumentParser(description="Evaluate RAG on the Golden Test Set")
    parser.add_argument("--golden", default="data/rag_golden.csv", help="Path to golden CSV")
    parser.add_argument("--top_k", type=int, default=10)
    parser.add_argument("--no-title-match", action="store_true", help="Disable relaxed title matching")
    args = parser.parse_args()

    m = evaluate_rag(
        golden_path=args.golden,
        top_k=args.top_k,
        use_title_match=not args.no_title_match,
    )

    print("\n" + "=" * 50)
    print("  RAG Golden Test Set Evaluation")
    print("=" * 50)
    print(f"  Queries: {m['n_queries']}")
    print(f"  Top-K:   {m['top_k']}")
    print(f"  Accuracy@{m['top_k']}: {m['accuracy_at_k']:.4f} (any relevant in top-K)")
    print(f"  Recall@{m['top_k']}:   {m['recall_at_k']:.4f} (fraction of relevant in top-K)")
    print(f"  MRR@{m['top_k']}:      {m['mrr_at_k']:.4f} (mean reciprocal rank)")
    print("=" * 50)


if __name__ == "__main__":
    main()
scripts/model/train_din_ranker.py
ADDED
@@ -0,0 +1,264 @@
#!/usr/bin/env python3
"""
Train DIN (Deep Interest Network) ranker.

Uses attention over the user behavior sequence w.r.t. the target item.
Reuses SASRec item embeddings as initialization when available.

Usage:
    python scripts/model/train_din_ranker.py
    python scripts/model/train_din_ranker.py --max_samples 10000 --epochs 10

Input:
    - data/rec/val.csv, train.csv
    - data/rec/user_sequences.pkl, item_map.pkl (from SASRec/YoutubeDNN)
    - data/model/rec/sasrec_model.pth (optional, for init)

Output:
    - data/model/ranking/din_ranker.pt
"""

import sys
import os

sys.path.append(os.getcwd())

import pickle
import logging
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from src.ranking.din import DIN
from src.recall.fusion import RecallFusion

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


def build_din_data(
    data_dir: str = "data/rec",
    model_dir: str = "data/model/recall",
    neg_ratio: int = 4,
    max_samples: int = 20000,
) -> tuple[pd.DataFrame, dict, dict]:
    """
    Build (user_id, isbn, label) samples with hard negatives.
    Returns (df, user_sequences, item_map).
    """
    logger.info("Building DIN training data...")
    val_df = pd.read_csv(f"{data_dir}/val.csv")
    all_items = pd.read_csv(f"{data_dir}/train.csv")["isbn"].astype(str).unique()

    if len(val_df) > max_samples:
        val_df = val_df.sample(n=max_samples, random_state=42).reset_index(drop=True)

    fusion = RecallFusion(data_dir, model_dir)
    fusion.load_models()

    with open(f"{data_dir}/user_sequences.pkl", "rb") as f:
        user_sequences = pickle.load(f)
    with open(f"{data_dir}/item_map.pkl", "rb") as f:
        item_map = pickle.load(f)

    rows = []
    for _, row in tqdm(val_df.iterrows(), total=len(val_df), desc="Mining samples"):
        user_id = str(row["user_id"])
        pos_isbn = str(row["isbn"])

        user_rows = [{"user_id": user_id, "isbn": pos_isbn, "label": 1}]

        # Hard negatives: high-scoring recall items the user did not pick
        try:
            recall_items = fusion.get_recall_items(user_id, k=50)
            hard_negs = [item for item, _ in recall_items if item != pos_isbn][:neg_ratio]
        except Exception:
            hard_negs = []

        for neg_isbn in hard_negs:
            user_rows.append({"user_id": user_id, "isbn": str(neg_isbn), "label": 0})

        # Fill up with random negatives if recall returned too few
        n_remaining = neg_ratio - len(hard_negs)
        if n_remaining > 0:
            random_negs = np.random.choice(all_items, size=n_remaining, replace=False)
            for neg_isbn in random_negs:
                user_rows.append({"user_id": user_id, "isbn": str(neg_isbn), "label": 0})

        rows.extend(user_rows)

    df = pd.DataFrame(rows)
    logger.info(f"Built {len(df)} samples")
    return df, user_sequences, item_map


class DINDataset(Dataset):
    """Dataset for DIN: (user_hist, target_item_id, label) and optional aux features."""

    def __init__(
        self,
        df: pd.DataFrame,
        user_sequences: dict,
        item_map: dict,
        max_hist_len: int = 50,
        aux_df: pd.DataFrame | None = None,
        aux_cols: list[str] | None = None,
    ):
        self.samples = []
        self.max_hist_len = max_hist_len
        self.aux_df = aux_df
        self.aux_cols = aux_cols or []
        for idx, (_, row) in enumerate(df.iterrows()):
            user_id = str(row["user_id"])
            isbn = str(row["isbn"])
            label = int(row["label"])
            target_id = item_map.get(isbn, 0)
            if target_id == 0:
                continue

            hist = user_sequences.get(user_id, [])
            if hist and isinstance(hist[0], str):
                hist = [item_map.get(h, 0) for h in hist if item_map.get(h, 0) > 0]
            hist = [x for x in hist if x != target_id][-max_hist_len:]

            self.samples.append((hist, target_id, label, idx))

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        hist, target_id, label, df_idx = self.samples[idx]
        padded = np.zeros(self.max_hist_len, dtype=np.int64)  # 0 = padding id
        padded[: len(hist)] = hist
        out = (
            torch.LongTensor(padded),
            torch.LongTensor([target_id]).squeeze(0),
            torch.FloatTensor([label]).squeeze(0),
        )
        if self.aux_df is not None and self.aux_cols:
            aux_row = self.aux_df.iloc[df_idx][self.aux_cols].values.astype(np.float32)
            out = out + (torch.FloatTensor(aux_row),)
        return out


def train_din(
    data_dir: str = "data/rec",
    model_dir: str = "data/model",
    recall_dir: str = "data/model/recall",
    max_samples: int = 20000,
    max_hist_len: int = 50,
    embed_dim: int = 64,
    epochs: int = 10,
    batch_size: int = 256,
    lr: float = 1e-3,
    use_aux: bool = False,
) -> None:
    rank_dir = Path(model_dir) / "ranking"
    rank_dir.mkdir(parents=True, exist_ok=True)

    df, user_sequences, item_map = build_din_data(
        data_dir, recall_dir, neg_ratio=4, max_samples=max_samples
    )
    num_items = len(item_map)

    aux_df = None
    aux_cols: list[str] = []
    if use_aux:
        from src.ranking.features import FeatureEngineer
        fe = FeatureEngineer(data_dir, recall_dir)
        fe.load_base_data()
        logger.info("Generating aux features for DIN...")
        aux_df = fe.create_dateset(df)
        aux_cols = [c for c in aux_df.columns if c not in ("label", "user_id", "isbn")]
        logger.info("Aux features: %s", aux_cols)

    num_aux = len(aux_cols)
    dataset = DINDataset(
        df, user_sequences, item_map,
        max_hist_len=max_hist_len,
        aux_df=aux_df,
        aux_cols=aux_cols if aux_cols else None,
    )
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # Optional warm start from SASRec item embeddings
    pretrained_emb = None
    sasrec_path = Path(model_dir) / "rec" / "sasrec_model.pth"
    if sasrec_path.exists():
        try:
            state = torch.load(sasrec_path, map_location="cpu", weights_only=False)
            emb = state.get("item_emb.weight")
            if emb is not None:
                pretrained_emb = emb.numpy()
                logger.info("Loaded SASRec item_emb for DIN init: %s", pretrained_emb.shape)
        except Exception as e:
            logger.warning("Could not load SASRec init: %s", e)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.backends.mps.is_available():
        device = torch.device("mps")

    model = DIN(
        num_items=num_items,
        embed_dim=embed_dim,
        max_hist_len=max_hist_len,
        num_aux=num_aux,
        pretrained_item_emb=pretrained_emb,
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    for ep in range(epochs):
        model.train()
        total_loss = 0.0
        n_batches = 0
        for batch in tqdm(loader, desc=f"Epoch {ep+1}/{epochs}"):
            hist = batch[0].to(device)
            target = batch[1].to(device)
            label = batch[2].to(device)
            aux = batch[3].to(device) if len(batch) > 3 else None
            opt.zero_grad()
            logits = model(hist, target, aux)
            loss = F.binary_cross_entropy_with_logits(logits, label)
            loss.backward()
            opt.step()
            total_loss += loss.item()
            n_batches += 1
        avg = total_loss / max(n_batches, 1)
        logger.info(f"Epoch {ep+1} loss: {avg:.4f}")

    ckpt = {
        "model": model,
        "item_map": item_map,
        "max_hist_len": max_hist_len,
        "aux_feature_names": aux_cols,
    }
    out_path = rank_dir / "din_ranker.pt"
    torch.save(ckpt, out_path)
    logger.info("DIN ranker saved to %s", out_path)

    # Rewrite user_sequences.pkl so serving uses exactly what DIN saw in training
    with open(Path(data_dir) / "user_sequences.pkl", "wb") as f:
        pickle.dump(user_sequences, f)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train DIN ranker")
    parser.add_argument("--max_samples", type=int, default=20000)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--aux", action="store_true", help="Use aux features from FeatureEngineer")
    args = parser.parse_args()

    train_din(
        max_samples=args.max_samples,
        epochs=args.epochs,
        batch_size=args.batch_size,
        use_aux=args.aux,
    )
scripts/model/train_intent_router.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+#!/usr/bin/env python3
+"""
+Train model-based intent classifier for Query Router.
+
+Replaces rule-based heuristics with TF-IDF + LogisticRegression (or FastText/DistilBERT).
+Uses synthetic seed data; extend with real labeled queries via --data CSV.
+
+Usage:
+    python scripts/model/train_intent_router.py
+    python scripts/model/train_intent_router.py --data data/intent_labels.csv
+    python scripts/model/train_intent_router.py --backend fasttext
+    python scripts/model/train_intent_router.py --backend distilbert
+
+Output:
+    data/model/intent_classifier.pkl (or .bin for fasttext)
+"""
+
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
+
+import joblib
+import logging
+
+import pandas as pd
+
+from src.core.intent_classifier import train_classifier, INTENTS
+
+logging.basicConfig(level=logging.INFO, format="%(message)s")
+logger = logging.getLogger(__name__)
+
+# Synthetic training data: (query, intent)
+# Extend with real user queries for better generalization
+SEED_DATA = [
+    # small_to_big: detail-oriented, plot/review focused
+    ("book with twist ending", "small_to_big"),
+    ("unreliable narrator", "small_to_big"),
+    ("spoiler about the ending", "small_to_big"),
+    ("what did readers think", "small_to_big"),
+    ("opinion on the book", "small_to_big"),
+    ("hidden details in the story", "small_to_big"),
+    ("did anyone cry reading this", "small_to_big"),
+    ("review of the book", "small_to_big"),
+    ("plot twist reveal", "small_to_big"),
+    ("unreliable narrator twist", "small_to_big"),
+    ("readers who loved the ending", "small_to_big"),
+    ("spoiler what happens at the end", "small_to_big"),
+    # fast: short keyword queries
+    ("AI book", "fast"),
+    ("Python", "fast"),
+    ("romance", "fast"),
+    ("machine learning", "fast"),
+    ("science fiction", "fast"),
+    ("best AI book", "fast"),
+    ("Python programming", "fast"),
+    ("self help", "fast"),
+    ("business", "fast"),
+    ("fiction", "fast"),
+    ("thriller", "fast"),
+    ("mystery novel", "fast"),
+    ("finance", "fast"),
+    ("history", "fast"),
+    ("psychology", "fast"),
+    ("data science", "fast"),
+    ("cooking", "fast"),
+    ("music", "fast"),
+    ("art", "fast"),
+    ("philosophy", "fast"),
+    # deep: natural language, complex queries
+    ("What are the best books about artificial intelligence for beginners", "deep"),
+    ("I'm looking for something similar to Harry Potter", "deep"),
+    ("Books that help you understand machine learning", "deep"),
+    ("Recommend me a book like Sapiens but about technology", "deep"),
+    ("I want to learn about psychology and human behavior", "deep"),
+    ("What should I read if I liked 1984", "deep"),
+    ("Looking for books on startup founding and entrepreneurship", "deep"),
+    ("Can you suggest books about climate change and sustainability", "deep"),
+    ("I need a book that explains quantum physics simply", "deep"),
+    ("Books for someone who wants to improve their writing skills", "deep"),
+    ("What are some good fiction books set in Japan", "deep"),
+    ("Recommendations for someone getting into philosophy", "deep"),
+    ("Books that discuss the future of work and automation", "deep"),
+    ("I'm interested in biographies of scientists", "deep"),
+    ("Something light and funny for a long flight", "deep"),
+    ("Books about the history of mathematics", "deep"),
+    ("Recommend me novels with strong female protagonists", "deep"),
+    ("What to read to understand economics", "deep"),
+    ("Books on meditation and mindfulness", "deep"),
+]
+
+
+def load_training_data(data_path: Path | None) -> tuple[list[str], list[str]]:
+    """Load (queries, labels) from SEED_DATA + optional CSV."""
+    queries = [q for q, _ in SEED_DATA]
+    labels = [l for _, l in SEED_DATA]
+
+    if data_path and data_path.exists():
+        df = pd.read_csv(data_path)
+        q_col = "query" if "query" in df.columns else df.columns[0]
+        l_col = "intent" if "intent" in df.columns else df.columns[1]
+        extra_q = df[q_col].astype(str).tolist()
+        extra_l = df[l_col].astype(str).tolist()
+        queries.extend(extra_q)
+        labels.extend(extra_l)
+        logger.info("Loaded %d extra samples from %s", len(extra_q), data_path)
+
+    return queries, labels
+
+
+def main():
+    import argparse
+    parser = argparse.ArgumentParser(description="Train intent classifier")
+    parser.add_argument("--data", type=Path, default=None, help="CSV with query,intent columns")
+    parser.add_argument("--backend", choices=["tfidf", "fasttext", "distilbert"], default="tfidf")
+    args = parser.parse_args()
+
+    project_root = Path(__file__).resolve().parent.parent.parent
+    out_dir = project_root / "data" / "model"
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    queries, labels = load_training_data(args.data)
+
+    logger.info("Training intent classifier (%s) on %d samples...", args.backend, len(queries))
+    result = train_classifier(queries, labels, backend=args.backend)
+
+    if args.backend == "fasttext":
+        out_path = out_dir / "intent_classifier.bin"
+        result.save_model(str(out_path))
+    else:
+        out_path = out_dir / "intent_classifier.pkl"
+        if args.backend == "distilbert":
+            joblib.dump(result, out_path)  # dict with pipeline, backend, etc.
+        else:
+            joblib.dump({"pipeline": result, "backend": "tfidf"}, out_path)
+
+    logger.info("Saved to %s", out_path)
+
+    # Quick sanity check
+    for intent in INTENTS:
+        sample = next((q for q, l in zip(queries, labels) if l == intent), None)
+        if sample:
+            if args.backend == "fasttext":
+                pred = result.predict(sample)[0][0].replace("__label__", "")
+            elif args.backend == "distilbert":
+                from transformers import pipeline
+                pipe = pipeline("zero-shot-classification", model="distilbert-base-uncased", device=-1)
+                pred = pipe(sample, INTENTS, multi_label=False)["labels"][0]
+            else:
+                pred = result.predict([sample])[0]
+            ok = "✓" if pred == intent else "✗"
+            logger.info("  %s %s: %r -> %s", ok, intent, sample[:40], pred)


+if __name__ == "__main__":
+    main()
scripts/model/train_ranker.py
CHANGED
@@ -15,18 +15,15 @@ Input:
     data/rec/train.csv (for fallback random negatives)
     data/model/recall/*.pkl (recall models for hard negative mining)
 
-
-
-
-
-
-    - data/model/ranking/xgb_ranker.json (full retrained XGB)
-    - data/model/ranking/stacking_meta.pkl (LogisticRegression meta-model)
+TIME-SPLIT (no leakage):
+    - Recall models (SASRec, etc.) are trained on train.csv only.
+    - Ranking uses val.csv for labels; recall for hard negatives.
+    - sasrec_score and user_seq_emb come from train-only SASRec.
+    - Pipeline order: split -> build_sequences(train-only) -> recall(train) -> ranker(val).
 
 Negative Sampling Strategy:
     - Hard negatives: items from recall results that are NOT the positive
     - Random negatives: fill remaining slots if recall returns too few
-    - This teaches the ranker to distinguish between "close but wrong" vs "right"
 """
 
 import sys
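
A sketch of the time-ordered split the TIME-SPLIT note describes, assuming an interactions table with a timestamp column (both the file name and the column name here are assumptions; the project's real split lives in its data scripts):

```python
import pandas as pd

# Sort by time so the train/val boundary is chronological, not random
df = pd.read_csv("data/rec/interactions.csv").sort_values("timestamp")

cut = int(len(df) * 0.9)
train, val = df.iloc[:cut], df.iloc[cut:]

# Recall models (e.g., SASRec) may only see `train`; the ranker takes its
# labels from `val`, so no future interaction leaks into recall features.
train.to_csv("data/rec/train.csv", index=False)
val.to_csv("data/rec/val.csv", index=False)
```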
scripts/run_pipeline.py
CHANGED
@@ -44,6 +44,7 @@ class Pipeline:
         skip_models: bool = False,
         skip_index: bool = False,
         stacking: bool = False,
+        train_din: bool = False,
     ):
         self.project_root = Path(project_root)
         self.data_dir = self.project_root / "data"
@@ -53,6 +54,7 @@ class Pipeline:
         self.skip_models = skip_models
         self.skip_index = skip_index
         self.stacking = stacking
+        self.train_din = train_din
 
     def _run_step(self, name: str, fn, *args, **kwargs):
         """Run a step with timing log."""
@@ -153,8 +155,19 @@ class Pipeline:
         from scripts.model.train_ranker import train_ranker, train_stacking
         self._run_step("Train Ranker", train_stacking if self.stacking else train_ranker)
 
+        from scripts.model.train_intent_router import main as train_intent
+        self._run_step("Train intent classifier", train_intent)
+
+        if getattr(self, "train_din", False):
+            from scripts.model.train_din_ranker import train_din
+            self._run_step("Train DIN ranker", lambda: train_din(
+                data_dir=str(self.rec_dir),
+                model_dir=str(self.model_dir),
+                recall_dir=str(self.model_dir / "recall"),
+            ))
+
     def run_evaluation(self) -> None:
-        """Stage 5: Validation."""
+        """Stage 5: Validation + RAG Golden Test Set (if exists)."""
         def _validate():
             from scripts.data.validate_data import (
                 validate_raw, validate_processed, validate_rec,
@@ -168,6 +181,18 @@ class Pipeline:
 
         self._run_step("Validate pipeline", _validate)
 
+        # RAG Golden Test Set evaluation (optional)
+        golden = self.rec_dir.parent / "rag_golden.csv"
+        if not golden.exists():
+            golden = self.rec_dir.parent / "rag_golden.example.csv"
+        if golden.exists():
+            def _run_rag_eval():
+                from scripts.model.evaluate_rag import evaluate_rag
+                m = evaluate_rag(str(golden))
+                logger.info("RAG Accuracy@%d: %.4f  Recall@%d: %.4f  MRR@%d: %.4f",
+                            m["top_k"], m["accuracy_at_k"], m["top_k"], m["recall_at_k"], m["top_k"], m["mrr_at_k"])
+            self._run_step("RAG Golden Test Set", _run_rag_eval)
+
     def run(self, stage: str = "all") -> None:
         """Execute full pipeline: Data Cleaning -> Training -> Evaluation."""
         logger.info("=" * 60)
@@ -201,6 +226,7 @@ def main():
     parser.add_argument("--validate-only", action="store_true", help="Only run validation")
     parser.add_argument("--device", default=None, help="Device for ML (cpu/cuda/mps)")
    parser.add_argument("--stacking", action="store_true", help="Enable stacking ranker")
+    parser.add_argument("--din", action="store_true", help="Train DIN ranker (deep model)")
     args = parser.parse_args()
 
     if args.validate_only:
@@ -213,6 +239,7 @@ def main():
         skip_models=args.skip_models,
         skip_index=args.skip_index,
         stacking=args.stacking,
+        train_din=args.din,
     )
     pipeline.run(stage=args.stage)
 
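
With the flag wired through, the deep ranker can be enabled from the CLI (`python scripts/run_pipeline.py --stacking --din`) or from the Python API; a small sketch (only `stacking`, `train_din`, and `project_root` are visible in this diff, so other constructor defaults are assumptions):

```python
from scripts.run_pipeline import Pipeline

pipeline = Pipeline(
    project_root=".",   # assumed acceptable; the class wraps it in Path()
    stacking=True,      # LGBMRanker + XGB -> LogisticRegression meta-learner
    train_din=True,     # additionally train the DIN ranker after the main ranker
)
pipeline.run(stage="all")
```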
src/core/freshness_monitor.py
ADDED
@@ -0,0 +1,231 @@
+"""
+Data Freshness Monitor.
+
+Provides insights into the freshness of the local book database:
+- Distribution of books by publication year
+- Detection of data staleness
+- Recommendations for when to trigger updates
+
+This module helps the system decide when to rely on local data vs.
+triggering external API fallbacks.
+"""
+from datetime import datetime
+from typing import Optional
+
+from src.core.metadata_store import metadata_store
+from src.utils import setup_logger
+
+logger = setup_logger(__name__)
+
+
+class FreshnessMonitor:
+    """
+    Monitor data freshness and provide staleness detection.
+
+    Usage:
+        monitor = FreshnessMonitor()
+        stats = monitor.get_data_stats()
+        if monitor.is_stale_for_query("latest 2025 books"):
+            # Trigger web search fallback
+    """
+
+    # Years considered "recent" for freshness calculations
+    RECENT_YEARS_THRESHOLD = 2
+
+    def __init__(self):
+        self._cache = {}
+        self._cache_timestamp = None
+        self._cache_ttl_seconds = 300  # 5 minutes
+
+    def _is_cache_valid(self) -> bool:
+        """Check if cached stats are still valid."""
+        if not self._cache or not self._cache_timestamp:
+            return False
+        age = (datetime.now() - self._cache_timestamp).total_seconds()
+        return age < self._cache_ttl_seconds
+
+    def get_data_stats(self, force_refresh: bool = False) -> dict:
+        """
+        Get comprehensive statistics about data freshness.
+
+        Returns:
+            Dict with:
+            - total_books: Total number of books in database
+            - newest_year: Year of most recently published book
+            - oldest_year: Year of oldest book
+            - books_by_year: Dict mapping year -> count
+            - recent_books_count: Books published in last N years
+            - data_cutoff_year: Effective "knowledge cutoff" year
+            - freshness_score: 0-100 score indicating data freshness
+        """
+        if not force_refresh and self._is_cache_valid():
+            return self._cache
+
+        stats = {
+            "total_books": 0,
+            "newest_year": None,
+            "oldest_year": None,
+            "books_by_year": {},
+            "recent_books_count": 0,
+            "data_cutoff_year": None,
+            "freshness_score": 0,
+            "last_checked": datetime.now().isoformat(),
+        }
+
+        try:
+            stats["total_books"] = metadata_store.get_book_count()
+            stats["books_by_year"] = metadata_store.get_books_by_year_distribution()
+
+            if stats["books_by_year"]:
+                years = sorted(stats["books_by_year"].keys())
+                stats["newest_year"] = max(years)
+                stats["oldest_year"] = min(years)
+                stats["data_cutoff_year"] = stats["newest_year"]
+
+                # Count recent books (last N years)
+                current_year = datetime.now().year
+                recent_threshold = current_year - self.RECENT_YEARS_THRESHOLD
+                stats["recent_books_count"] = sum(
+                    count for year, count in stats["books_by_year"].items()
+                    if year >= recent_threshold
+                )
+
+                # Calculate freshness score (0-100)
+                # Based on: newest year relative to current year
+                years_behind = current_year - (stats["newest_year"] or current_year)
+                stats["freshness_score"] = max(0, 100 - (years_behind * 25))
+
+            self._cache = stats
+            self._cache_timestamp = datetime.now()
+
+        except Exception as e:
+            logger.error(f"FreshnessMonitor.get_data_stats failed: {e}")
+
+        return stats
+
+    def is_stale(self, target_year: Optional[int] = None) -> bool:
+        """
+        Check if local data is too old for a given target year.
+
+        Args:
+            target_year: Year the user is asking about (default: current year)
+
+        Returns:
+            True if data is stale and web fallback should be triggered
+        """
+        if target_year is None:
+            target_year = datetime.now().year
+
+        stats = self.get_data_stats()
+        newest_year = stats.get("newest_year")
+
+        if newest_year is None:
+            return True  # No data at all
+
+        # Stale if target year is newer than our newest data
+        return target_year > newest_year
+
+    def is_stale_for_query(self, query: str) -> bool:
+        """
+        Analyze a query and determine if data is stale for it.
+
+        Args:
+            query: User's search query
+
+        Returns:
+            True if web fallback should be triggered
+        """
+        from src.core.web_search import extract_year_from_query
+
+        target_year = extract_year_from_query(query)
+
+        if target_year is None:
+            # No year requirement - check freshness score
+            stats = self.get_data_stats()
+            # Trigger fallback if data is more than 2 years old
+            return stats.get("freshness_score", 100) < 50
+
+        return self.is_stale(target_year)
+
+    def get_coverage_for_year(self, year: int) -> dict:
+        """
+        Get coverage statistics for a specific year.
+
+        Returns:
+            Dict with: count, percentage of total, is_well_covered
+        """
+        stats = self.get_data_stats()
+        year_count = stats["books_by_year"].get(year, 0)
+        total = stats["total_books"] or 1
+
+        return {
+            "year": year,
+            "count": year_count,
+            "percentage": round(year_count / total * 100, 2),
+            "is_well_covered": year_count >= 100,  # Arbitrary threshold
+        }
+
+    def recommend_update_categories(self) -> list[str]:
+        """
+        Recommend categories that should be updated.
+
+        Returns:
+            List of category names that need fresh data
+        """
+        # This would require category-level year tracking
+        # For now, return common categories that benefit from freshness
+        return [
+            "fiction",
+            "thriller",
+            "science fiction",
+            "fantasy",
+            "mystery",
+            "self-help",
+            "business",
+        ]
+
+    def get_summary(self) -> str:
+        """
+        Get a human-readable summary of data freshness.
+
+        Returns:
+            Formatted string describing data freshness status
+        """
+        stats = self.get_data_stats()
+
+        lines = [
+            "Data Freshness Report",
+            "=" * 40,
+            f"Total books: {stats['total_books']:,}",
+            f"Newest book year: {stats['newest_year'] or 'Unknown'}",
+            f"Data cutoff: {stats['data_cutoff_year'] or 'Unknown'}",
+            f"Recent books (last {self.RECENT_YEARS_THRESHOLD} years): {stats['recent_books_count']:,}",
+            f"Freshness score: {stats['freshness_score']}/100",
+        ]
+
+        current_year = datetime.now().year
+        if stats["newest_year"] and stats["newest_year"] < current_year:
+            years_behind = current_year - stats["newest_year"]
+            lines.append("")
+            lines.append(f"WARNING: Data is {years_behind} year(s) behind current year.")
+            lines.append(f"Consider running: python scripts/data/fetch_new_books.py --year {current_year}")
+
+        return "\n".join(lines)
+
+
+# Global instance
+freshness_monitor = FreshnessMonitor()
+
+
+# Convenience function for quick checks
+def is_data_fresh_enough(query: str) -> bool:
+    """
+    Quick check if local data is fresh enough for a query.
+
+    Args:
+        query: User's search query
+
+    Returns:
+        True if local data is sufficient, False if web fallback recommended
+    """
+    return not freshness_monitor.is_stale_for_query(query)
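
How the monitor is meant to be driven, mirroring the class docstring (the printed numbers depend on the local catalog):

```python
from src.core.freshness_monitor import freshness_monitor, is_data_fresh_enough

stats = freshness_monitor.get_data_stats()
print(f"{stats['total_books']} books, newest year {stats['newest_year']}, "
      f"score {stats['freshness_score']}/100")

# Route-level check: False means the caller should enable the web fallback
if not is_data_fresh_enough("best sci-fi novels of 2025"):
    print("Local data stale for this query -> trigger Google Books fallback")

print(freshness_monitor.get_summary())
```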
src/core/intent_classifier.py
ADDED
@@ -0,0 +1,204 @@
+"""
+Model-based intent classifier for Query Router.
+
+Replaces brittle rule-based heuristics with a trained classifier.
+Backends: tfidf (default), fasttext, distilbert.
+
+Intents: small_to_big (detail), fast (keyword), deep (natural language)
+"""
+
+import logging
+from pathlib import Path
+from typing import Optional
+
+import joblib
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.linear_model import LogisticRegression
+from sklearn.pipeline import Pipeline
+
+logger = logging.getLogger(__name__)
+
+INTENTS = ["small_to_big", "fast", "deep"]
+
+
+class IntentClassifier:
+    """
+    Intent classifier with pluggable backends:
+    - tfidf: TF-IDF + LogisticRegression (~1–2ms)
+    - fasttext: FastText (~1ms, requires fasttext package)
+    - distilbert: Zero-shot DistilBERT (~50–100ms, higher accuracy)
+    """
+
+    def __init__(self, model_path: Optional[Path] = None):
+        self.pipeline: Optional[Pipeline] = None
+        self._fasttext_model = None
+        self._distilbert_pipeline = None
+        self._backend = "tfidf"
+        self.model_path = Path(model_path) if model_path else None
+
+    def load(self, path: Optional[Path] = None) -> bool:
+        """Load trained model from disk."""
+        p = path or self.model_path
+        if not p:
+            return False
+        p = Path(p)
+        base = p.parent if p.suffix in (".pkl", ".bin") else p
+        pkl_path = p if p.suffix == ".pkl" else base / "intent_classifier.pkl"
+        bin_path = p if p.suffix == ".bin" else base / "intent_classifier.bin"
+
+        # Try .pkl first (tfidf or distilbert)
+        if pkl_path.exists():
+            try:
+                data = joblib.load(pkl_path)
+                if isinstance(data, dict):
+                    self.pipeline = data.get("pipeline")
+                    self._backend = data.get("backend", "tfidf")
+                    if self._backend == "distilbert":
+                        self._load_distilbert(data)
+                else:
+                    self.pipeline = data
+                self.model_path = pkl_path
+                logger.info("Intent classifier loaded from %s (backend=%s)", pkl_path, self._backend)
+                return True
+            except Exception as e:
+                logger.warning("Failed to load intent classifier: %s", e)
+
+        # Try .bin (FastText)
+        if bin_path.exists():
+            try:
+                import fasttext
+                self._fasttext_model = fasttext.load_model(str(bin_path))
+                self._backend = "fasttext"
+                self.model_path = bin_path
+                logger.info("Intent classifier loaded from %s (FastText)", bin_path)
+                return True
+            except ImportError:
+                logger.warning("FastText not installed; pip install fasttext")
+            except Exception as e:
+                logger.warning("Failed to load FastText: %s", e)
+
+        return False
+
+    def _load_distilbert(self, data: dict) -> None:
+        """Lazy-load DistilBERT pipeline from saved config."""
+        model_name = data.get("distilbert_model", "distilbert-base-uncased")
+        try:
+            from transformers import pipeline
+            self._distilbert_pipeline = pipeline(
+                "zero-shot-classification",
+                model=model_name,
+                device=-1,
+            )
+        except Exception as e:
+            logger.warning("DistilBERT pipeline load failed: %s", e)
+        self.pipeline = None  # Use distilbert, not sklearn pipeline
+
+    def predict(self, query: str) -> str:
+        """Predict intent for a query. Returns one of small_to_big, fast, deep."""
+        q = query.strip()
+        if not q:
+            return "deep"
+
+        if self._fasttext_model is not None:
+            pred = self._fasttext_model.predict(q)
+            return pred[0][0].replace("__label__", "")
+
+        if self._distilbert_pipeline is not None:
+            out = self._distilbert_pipeline(q, INTENTS, multi_label=False)
+            return out["labels"][0]
+
+        if self.pipeline is None:
+            raise RuntimeError("Intent classifier not loaded; call load() first")
+        return str(self.pipeline.predict([q])[0])
+
+    def predict_proba(self, query: str) -> dict[str, float]:
+        """Return intent probabilities for debugging."""
+        q = query.strip()
+        if not q:
+            return {i: 1.0 / len(INTENTS) for i in INTENTS}
+
+        if self._fasttext_model is not None:
+            pred = self._fasttext_model.predict(q, k=len(INTENTS))
+            return dict(zip([l.replace("__label__", "") for l in pred[0]], pred[1]))
+
+        if self._distilbert_pipeline is not None:
+            out = self._distilbert_pipeline(q, INTENTS, multi_label=False)
+            return dict(zip(out["labels"], out["scores"]))
+
+        if self.pipeline is None:
+            raise RuntimeError("Intent classifier not loaded")
+        probs = self.pipeline.predict_proba([q])[0]
+        last_step = self.pipeline.steps[-1][1]
+        classes = getattr(last_step, "classes_", INTENTS)
+        return dict(zip(classes, probs))
+
+
+def train_classifier(
+    queries: list[str],
+    labels: list[str],
+    max_features: int = 5000,
+    C: float = 1.0,
+    backend: str = "tfidf",
+):
+    """
+    Train intent classifier. Returns pipeline (tfidf), model (fasttext), or dict (distilbert).
+    """
+    if backend == "fasttext":
+        return _train_fasttext(queries, labels)
+    if backend == "distilbert":
+        return _train_distilbert(queries, labels)
+    # tfidf default
+    pipeline = Pipeline([
+        ("tfidf", TfidfVectorizer(
+            max_features=max_features,
+            ngram_range=(1, 2),
+            min_df=1,
+            lowercase=True,
+        )),
+        ("clf", LogisticRegression(
+            C=C,
+            max_iter=500,
+            class_weight="balanced",
+            random_state=42,
+        )),
+    ])
+    pipeline.fit(queries, labels)
+    return pipeline
+
+
+def _train_fasttext(queries: list[str], labels: list[str]):
+    """Train FastText classifier. Requires fasttext package."""
+    try:
+        import fasttext
+        import tempfile
+        import os
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
+            for q, l in zip(queries, labels):
+                line = q.replace("\n", " ").strip()
+                f.write(f"__label__{l} {line}\n")
+            path = f.name
+        model = fasttext.train_supervised(path, epoch=25, lr=0.5, wordNgrams=2)
+        os.unlink(path)
+        return model
+    except ImportError:
+        raise RuntimeError("FastText not installed: pip install fasttext")
+
+
+def _train_distilbert(queries: list[str], labels: list[str]) -> dict:
+    """DistilBERT zero-shot: creates pipeline (no training). Saves config for inference."""
+    try:
+        from transformers import pipeline
+        pipe = pipeline(
+            "zero-shot-classification",
+            model="distilbert-base-uncased",
+            device=-1,
+        )
+        return {
+            "backend": "distilbert",
+            "distilbert_model": "distilbert-base-uncased",
+            "intents": INTENTS,
+        }
+    except Exception as e:
+        raise RuntimeError(f"DistilBERT setup failed: {e}")
src/core/metadata_store.py
CHANGED
@@ -142,6 +142,148 @@ class MetadataStore:
             logger.error(f"MetadataStore insert_book failed: {e}")
             return False
 
+    def insert_book_with_fts(self, row: Dict[str, Any]) -> bool:
+        """
+        Insert a new book into both main table AND FTS5 index.
+
+        This enables incremental indexing - new books are immediately searchable
+        via keyword search without requiring a full index rebuild.
+
+        Args:
+            row: Book data dict with keys: isbn13, title, description, authors, simple_categories, etc.
+
+        Returns:
+            True if successful, False otherwise
+        """
+        conn = self.connection
+        if not conn:
+            return False
+
+        try:
+            # 1. Insert into main books table
+            if not self.insert_book(row):
+                return False
+
+            # 2. Insert into FTS5 index
+            # FTS5 columns: isbn13, title, description, authors, simple_categories
+            isbn13 = str(row.get("isbn13", ""))
+            title = str(row.get("title", ""))
+            description = str(row.get("description", ""))
+            authors = str(row.get("authors", ""))
+            categories = str(row.get("simple_categories", ""))
+
+            # Check if FTS5 table exists
+            cursor = conn.cursor()
+            cursor.execute(
+                "SELECT name FROM sqlite_master WHERE type='table' AND name='books_fts'"
+            )
+            if not cursor.fetchone():
+                logger.warning("MetadataStore: FTS5 table 'books_fts' not found. Skipping FTS index.")
+                return True  # Main insert succeeded, FTS just not available
+
+            # Insert into FTS5 (use INSERT OR REPLACE to handle updates)
+            cursor.execute(
+                """
+                INSERT OR REPLACE INTO books_fts (isbn13, title, description, authors, simple_categories)
+                VALUES (?, ?, ?, ?, ?)
+                """,
+                (isbn13, title, description, authors, categories)
+            )
+            conn.commit()
+
+            logger.info(f"MetadataStore: Inserted book {isbn13} into FTS5 index")
+            return True
+
+        except sqlite3.OperationalError as e:
+            # FTS5 might not support OR REPLACE, try without
+            if "REPLACE" in str(e):
+                try:
+                    cursor = conn.cursor()
+                    cursor.execute(
+                        """
+                        INSERT INTO books_fts (isbn13, title, description, authors, simple_categories)
+                        VALUES (?, ?, ?, ?, ?)
+                        """,
+                        (isbn13, title, description, authors, categories)
+                    )
+                    conn.commit()
+                    return True
+                except Exception as inner_e:
+                    logger.error(f"MetadataStore FTS5 insert failed: {inner_e}")
+                    return False
+            logger.error(f"MetadataStore FTS5 insert failed: {e}")
+            return False
+        except Exception as e:
+            logger.error(f"MetadataStore insert_book_with_fts failed: {e}")
+            return False
+
+    def book_exists(self, isbn: str) -> bool:
+        """Check if a book with given ISBN exists in the database."""
+        isbn = str(isbn).strip().replace(".0", "")
+        row = self._query_one(
+            "SELECT 1 FROM books WHERE isbn13 = ? OR isbn10 = ? LIMIT 1",
+            (isbn, isbn)
+        )
+        return row is not None
+
+    def get_newest_book_year(self) -> Optional[int]:
+        """Get the publication year of the newest book in the database."""
+        conn = self.connection
+        if not conn:
+            return None
+        try:
+            cursor = conn.cursor()
+            # Try publishedDate column
+            cursor.execute(
+                "SELECT publishedDate FROM books WHERE publishedDate IS NOT NULL "
+                "ORDER BY publishedDate DESC LIMIT 1"
+            )
+            row = cursor.fetchone()
+            if row and row[0]:
+                # Extract year from date string
+                date_str = str(row[0])
+                if len(date_str) >= 4:
+                    return int(date_str[:4])
+        except Exception as e:
+            logger.debug(f"get_newest_book_year failed: {e}")
+        return None
+
+    def get_book_count(self) -> int:
+        """Get total number of books in the database."""
+        conn = self.connection
+        if not conn:
+            return 0
+        try:
+            cursor = conn.cursor()
+            cursor.execute("SELECT COUNT(*) FROM books")
+            row = cursor.fetchone()
+            return row[0] if row else 0
+        except Exception as e:
+            logger.error(f"get_book_count failed: {e}")
+            return 0
+
+    def get_books_by_year_distribution(self) -> Dict[int, int]:
+        """Get distribution of books by publication year."""
+        conn = self.connection
+        if not conn:
+            return {}
+        try:
+            cursor = conn.cursor()
+            cursor.execute(
+                """
+                SELECT SUBSTR(publishedDate, 1, 4) as year, COUNT(*) as count
+                FROM books
+                WHERE publishedDate IS NOT NULL AND LENGTH(publishedDate) >= 4
+                GROUP BY year
+                ORDER BY year DESC
+                LIMIT 20
+                """
+            )
+            return {int(row[0]): row[1] for row in cursor.fetchall() if row[0].isdigit()}
+        except Exception as e:
+            logger.debug(f"get_books_by_year_distribution failed: {e}")
+            return {}
+
     def load_books_processed(self): pass
     def load_train_data(self): pass
 
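
How a fetch pipeline is expected to use these helpers: check existence first, then do the dual insert so keyword search picks the book up immediately (the `row` values below are illustrative placeholders, not a real book):

```python
from src.core.metadata_store import metadata_store

row = {
    "isbn13": "9790000000000",  # placeholder identifier for illustration
    "title": "Example Title",
    "authors": "Jane Doe",
    "description": "Placeholder description.",
    "simple_categories": "Fiction",
}

if not metadata_store.book_exists(row["isbn13"]):
    ok = metadata_store.insert_book_with_fts(row)  # books table + books_fts index
    print("inserted" if ok else "insert failed")

print(metadata_store.get_book_count(), metadata_store.get_newest_book_year())
```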
src/core/router.py
CHANGED
@@ -1,74 +1,178 @@
 import re
-from 
+from pathlib import Path
+from typing import Dict, Any, List, Optional
+
 from src.utils import setup_logger
 
 logger = setup_logger(__name__)
 
+
 class QueryRouter:
     """
     Intelligent Router for the RAG Pipeline.
     Classifies user queries to select the optimal retrieval strategy.
-
+
+    Uses model-based intent classifier when available; falls back to rule-based
+    heuristics when classifier not trained/loaded.
+
     Strategies:
     1. EXACT (ISBN/ID) -> Pure BM25 (High Precision, No Rerank noise).
     2. FAST (Keywords) -> Hybrid (RRF), No Rerank (Low Latency).
     3. DEEP (Complex) -> Hybrid + Rerank (High Latency, High contextual relevance).
-    """
-
-
-
 
+    Freshness-Aware Routing:
+    - Detects queries asking for "new", "latest", or specific years (2024, 2025, etc.)
+    - Sets freshness_fallback=True to enable Web Search when local results insufficient
+    """
+
+    # Keywords that indicate user wants fresh/recent content
+    # Note: Year numbers are detected dynamically in _detect_freshness()
+    FRESHNESS_KEYWORDS = {
+        "new", "newest", "latest", "recent", "modern", "contemporary", "current",
+    }
+
+    # Strong freshness indicators (always trigger fallback)
+    STRONG_FRESHNESS_KEYWORDS = {
+        "newest", "latest",
+    }
+
+    def __init__(self, model_dir: str | Path | None = None):
+        self.isbn_pattern = re.compile(r"^(?:\d{9}[\dX]|\d{13})$")
+        if model_dir is None:
+            from src.config import DATA_DIR
+            model_dir = DATA_DIR / "model"
+        self.model_dir = Path(model_dir)
+        self._classifier = None
+
+    def _get_classifier(self):
+        """Lazy-load intent classifier when first needed."""
+        if self._classifier is not None:
+            return self._classifier
+        try:
+            from src.core.intent_classifier import IntentClassifier
+            clf = IntentClassifier(self.model_dir / "intent_classifier.pkl")
+            if clf.load():
+                self._classifier = clf
+        except Exception as e:
+            logger.debug("Intent classifier not available: %s", e)
+        return self._classifier
+
+    def _detect_freshness(self, words: list) -> tuple[bool, bool, Optional[int]]:
+        """
+        Detect if query requires fresh content.
+
+        Returns:
+            (is_temporal, freshness_fallback, target_year)
+            - is_temporal: Should apply temporal boost to local results
+            - freshness_fallback: Should enable Web Search if local results insufficient
+            - target_year: Specific year user is looking for (if detected)
+        """
+        from datetime import datetime
+        current_year = datetime.now().year
+
+        lower_words = {w.lower() for w in words}
+
+        is_temporal = bool(lower_words & self.FRESHNESS_KEYWORDS)
+        freshness_fallback = bool(lower_words & self.STRONG_FRESHNESS_KEYWORDS)
+
+        # Extract explicit year from query
+        target_year = None
+        for word in words:
+            if word.isdigit() and len(word) == 4:
+                year = int(word)
+                if 2000 <= year <= 2050:
+                    target_year = year
+                    # Recent years (within last 3 years) trigger freshness
+                    if year >= current_year - 2:
+                        is_temporal = True
+                        freshness_fallback = True
+                    break
+
+        return is_temporal, freshness_fallback, target_year
+
+    def _route_by_rules(
+        self,
+        cleaned_query: str,
+        words: list,
+        is_temporal: bool,
+        freshness_fallback: bool = False,
+        target_year: Optional[int] = None
+    ) -> Dict[str, Any]:
+        """Fallback: rule-based routing (original logic + freshness)."""
+        detail_keywords = {
+            "twist", "ending", "spoiler", "readers", "felt", "cried", "hated", "loved",
+            "review", "opinion", "think", "unreliable", "narrator", "realize", "find out",
+        }
+
+        base_result = {
+            "temporal": is_temporal,
+            "freshness_fallback": freshness_fallback,
+            "freshness_threshold": 3,  # Trigger web search if < 3 results
+            "target_year": target_year,
+        }
+
+        if any(w.lower() in detail_keywords for w in words):
+            logger.info("Router (rules): Detail Query -> SMALL_TO_BIG")
+            return {**base_result, "strategy": "small_to_big", "alpha": 0.5, "rerank": False, "k_final": 5}
+        if len(words) <= 2:
+            logger.info("Router (rules): Keyword -> FAST (Temporal=%s, Freshness=%s)", is_temporal, freshness_fallback)
+            return {**base_result, "strategy": "fast", "alpha": 0.5, "rerank": False, "k_final": 5}
+        logger.info("Router (rules): Natural Language -> DEEP (Temporal=%s, Freshness=%s)", is_temporal, freshness_fallback)
+        return {**base_result, "strategy": "deep", "alpha": 0.5, "rerank": True, "k_final": 10}
+
     def route(self, query: str) -> Dict[str, Any]:
         """
         Analyze query and return retrieval parameters.
-
+
+        Returns dict with:
+        - 'strategy': 'exact' | 'fast' | 'deep' | 'small_to_big'
+        - 'alpha': float (hybrid search weight)
+        - 'rerank': bool (use cross-encoder reranking)
+        - 'k_final': int (number of results)
+        - 'temporal': bool (apply temporal boost)
+        - 'freshness_fallback': bool (enable web search if local results insufficient)
+        - 'freshness_threshold': int (min local results before triggering web search)
+        - 'target_year': int | None (specific year user requested)
         """
         cleaned_query = query.strip()
         words = cleaned_query.split()
-
-        # 1.
-        # Remove hyphens/spaces for check
+
+        # 1. ISBN: keep regex (deterministic, correct)
         normalized = cleaned_query.replace("-", "").replace(" ", "")
         if self.isbn_pattern.match(normalized):
-            logger.info(
-            return {"strategy": "exact", "alpha": 1.0, "rerank": False, "k_final": 5}
-
-        # 2. Check for Temporal Keywords (Freshness Bias)
-        temporal_keywords = {"new", "newest", "latest", "recent", "modern", "contemporary", "2020", "2021", "2022", "2023", "2024", "2025"}
-        is_temporal = any(word.lower() in temporal_keywords for word in words)
-
-        # 3. Check for Detail-Oriented Queries (Triggers Small-to-Big)
-        # These are queries asking about specific plot points, reactions, or hidden details
-        detail_keywords = {"twist", "ending", "spoiler", "readers", "felt", "cried", "hated", "loved",
-                           "review", "opinion", "think", "unreliable", "narrator", "realize", "find out"}
-        is_detail = any(word.lower() in detail_keywords for word in words)
-
-        if is_detail:
-            logger.info(f"Router: Detected Detail Query -> SMALL_TO_BIG Strategy")
+            logger.info("Router: ISBN -> EXACT (%s)", normalized)
             return {
-                "strategy": "
-                "
+                "strategy": "exact",
+                "alpha": 1.0,
+                "rerank": False,
                 "k_final": 5,
-                "temporal":
+                "temporal": False,
+                "freshness_fallback": False,
+                "freshness_threshold": 1,
+                "target_year": None,
             }
-
-        # ... (remaining rule-based branches; elided in the diff view)
+
+        # 2. Freshness detection (temporal boost + web fallback)
+        is_temporal, freshness_fallback, target_year = self._detect_freshness(words)
+
+        # 3. Model-based vs rule-based intent
+        clf = self._get_classifier()
+        if clf is not None:
+            try:
+                intent = clf.predict(cleaned_query)
+                logger.info("Router (model): %s (Freshness=%s)", intent.upper(), freshness_fallback)
+                return {
+                    "strategy": intent,
+                    "alpha": 0.5,
+                    "rerank": intent == "deep",
+                    "k_final": 10 if intent == "deep" else 5,
+                    "temporal": is_temporal,
+                    "freshness_fallback": freshness_fallback,
+                    "freshness_threshold": 3,
+                    "target_year": target_year,
+                }
+            except Exception as e:
+                logger.warning("Intent classifier failed, falling back to rules: %s", e)
+
+        return self._route_by_rules(cleaned_query, words, is_temporal, freshness_fallback, target_year)
 
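
What `route()` returns in practice, following the contract in its docstring (the freshness assertions are deterministic; the intent for the last query depends on whether the trained classifier is loaded):

```python
from src.core.router import QueryRouter

router = QueryRouter()

# Any 13-digit token matches the ISBN regex -> EXACT strategy, no rerank
print(router.route("9790000000000"))  # placeholder 13-digit identifier

params = router.route("latest sci-fi 2025")
assert params["freshness_fallback"] is True   # "latest" is a strong keyword
assert params["target_year"] == 2025          # explicit year token detected

print(router.route("Recommend me a book like Sapiens but about technology"))
# -> deep strategy with rerank=True and k_final=10 (model- or rule-based)
```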
src/core/web_search.py
ADDED
@@ -0,0 +1,323 @@
+"""
+Web Search Fallback Module for Data Freshness.
+
+Extends the existing Google Books integration in cover_fetcher.py to support
+full book metadata retrieval and keyword-based search for new books.
+
+This module provides:
+- search_google_books(): Search books by keyword query
+- fetch_book_by_isbn(): Get complete metadata for a single book
+- is_fresh_enough(): Evaluate if local search results meet freshness requirements
+
+API: Google Books API (free tier, no auth required for basic queries)
+Rate Limit: ~1000 requests/day (unofficial), implement conservative caching
+"""
+import requests
+from typing import Optional
+from functools import lru_cache
+from datetime import datetime
+
+from src.utils import setup_logger
+
+logger = setup_logger(__name__)
+
+# Google Books API endpoint
+GOOGLE_BOOKS_API = "https://www.googleapis.com/books/v1/volumes"
+
+# Request timeout (seconds)
+REQUEST_TIMEOUT = 5
+
+
+def _parse_isbn_from_identifiers(identifiers: list[dict]) -> tuple[str, str]:
+    """
+    Extract ISBN-13 and ISBN-10 from Google Books industryIdentifiers.
+
+    Returns:
+        (isbn13, isbn10) - empty strings if not found
+    """
+    isbn13, isbn10 = "", ""
+    for ident in identifiers:
+        id_type = ident.get("type", "")
+        id_value = ident.get("identifier", "")
+        if id_type == "ISBN_13":
+            isbn13 = id_value
+        elif id_type == "ISBN_10":
+            isbn10 = id_value
+    return isbn13, isbn10
+
+
+def _parse_volume_info(volume_info: dict) -> Optional[dict]:
+    """
+    Parse Google Books volumeInfo into our standard book format.
+
+    Returns:
+        dict with keys: isbn13, isbn10, title, authors, description,
+        publishedDate, thumbnail, categories
+        None if ISBN is missing (we can't index without ISBN)
+    """
+    identifiers = volume_info.get("industryIdentifiers", [])
+    isbn13, isbn10 = _parse_isbn_from_identifiers(identifiers)
+
+    # Skip books without ISBN - can't be indexed reliably
+    if not isbn13 and not isbn10:
+        return None
+
+    # Use isbn13 as primary, fallback to isbn10
+    primary_isbn = isbn13 if isbn13 else isbn10
+
+    # Extract image links (prefer larger sizes)
+    image_links = volume_info.get("imageLinks", {})
+    thumbnail = (
+        image_links.get("extraLarge") or
+        image_links.get("large") or
+        image_links.get("medium") or
+        image_links.get("small") or
+        image_links.get("thumbnail") or
+        ""
+    )
+    # Ensure HTTPS
+    if thumbnail.startswith("http://"):
+        thumbnail = thumbnail.replace("http://", "https://")
+
+    # Categories: Google returns list, we join to single string
+    categories = volume_info.get("categories", [])
+    category_str = categories[0] if categories else "General"
+
+    return {
+        "isbn13": primary_isbn,  # Primary key
+        "isbn10": isbn10 or (isbn13[:10] if isbn13 else ""),  # parenthesized so a lone ISBN-10 survives
+        "title": volume_info.get("title", "Unknown Title"),
+        "authors": ", ".join(volume_info.get("authors", ["Unknown"])),
+        "description": volume_info.get("description", ""),
+        "publishedDate": volume_info.get("publishedDate", ""),
+        "thumbnail": thumbnail,
+        "simple_categories": category_str,
+        "average_rating": volume_info.get("averageRating", 0.0),
+        "source": "google_books",  # Track data source
+    }
+
+
+def search_google_books(query: str, max_results: int = 10) -> list[dict]:
+    """
+    Search Google Books by keyword query.
+
+    Args:
+        query: Search query (e.g., "latest sci-fi 2024", "new fantasy novels")
+        max_results: Maximum number of results (1-40, Google's limit)
+
+    Returns:
+        List of book dicts in standard format, ordered by relevance
+    """
+    if not query or not query.strip():
+        return []
+
+    max_results = min(max_results, 40)  # Google's API limit
+
+    try:
+        params = {
+            "q": query,
+            "maxResults": max_results,
+            "printType": "books",
+            "orderBy": "relevance",
+        }
+
+        response = requests.get(
+            GOOGLE_BOOKS_API,
+            params=params,
+            timeout=REQUEST_TIMEOUT
+        )
+
+        if response.status_code != 200:
+            logger.warning(f"Google Books API returned {response.status_code}")
+            return []
+
+        data = response.json()
+        total_items = data.get("totalItems", 0)
+
+        if total_items == 0:
+            logger.info(f"No results for query: {query}")
+            return []
+
+        items = data.get("items", [])
+        results = []
+
+        for item in items:
+            volume_info = item.get("volumeInfo", {})
+            parsed = _parse_volume_info(volume_info)
+            if parsed:
+                results.append(parsed)
+
+        logger.info(f"Google Books search '{query}': {len(results)} valid results")
+        return results
+
+    except requests.Timeout:
+        logger.warning(f"Google Books API timeout for query: {query}")
+        return []
+    except requests.RequestException as e:
+        logger.error(f"Google Books API request failed: {e}")
+        return []
+    except Exception as e:
+        logger.error(f"Unexpected error in search_google_books: {e}")
+        return []
+
+
+@lru_cache(maxsize=500)
+def fetch_book_by_isbn(isbn: str) -> Optional[dict]:
+    """
+    Fetch complete book metadata by ISBN from Google Books.
+
+    Args:
+        isbn: ISBN-10 or ISBN-13
+
+    Returns:
+        Book dict in standard format, or None if not found
+    """
+    if not isbn or not isbn.strip():
+        return None
+
+    isbn = isbn.strip().replace("-", "")
+
+    try:
+        params = {
+            "q": f"isbn:{isbn}",
+            "maxResults": 1,
+        }
+
+        response = requests.get(
+            GOOGLE_BOOKS_API,
+            params=params,
+            timeout=REQUEST_TIMEOUT
+        )
+
+        if response.status_code != 200:
+            return None
+
+        data = response.json()
+        if data.get("totalItems", 0) == 0:
+            return None
+
+        items = data.get("items", [])
+        if not items:
+            return None
+
+        volume_info = items[0].get("volumeInfo", {})
+        return _parse_volume_info(volume_info)
+
+    except Exception as e:
+        logger.debug(f"fetch_book_by_isbn({isbn}) failed: {e}")
+        return None
+
+
+def search_new_books_by_category(
+    category: str,
+    year: Optional[int] = None,
+    max_results: int = 10
+) -> list[dict]:
+    """
+    Search for recently published books in a specific category.
+
+    Args:
+        category: Book category (e.g., "fiction", "science fiction", "mystery")
+        year: Filter by publication year (default: current year)
+        max_results: Maximum number of results
+
+    Returns:
+        List of book dicts, filtered to specified year or newer
+    """
+    if year is None:
+        year = datetime.now().year
|
| 229 |
+
|
| 230 |
+
# Build query with subject filter
|
| 231 |
+
query = f"subject:{category}"
|
| 232 |
+
|
| 233 |
+
# Get more results than needed, filter by year locally
|
| 234 |
+
raw_results = search_google_books(query, max_results=max_results * 2)
|
| 235 |
+
|
| 236 |
+
filtered = []
|
| 237 |
+
for book in raw_results:
|
| 238 |
+
pub_date = book.get("publishedDate", "")
|
| 239 |
+
if pub_date:
|
| 240 |
+
# Extract year from various date formats (YYYY, YYYY-MM, YYYY-MM-DD)
|
| 241 |
+
try:
|
| 242 |
+
pub_year = int(pub_date[:4])
|
| 243 |
+
if pub_year >= year:
|
| 244 |
+
filtered.append(book)
|
| 245 |
+
except (ValueError, IndexError):
|
| 246 |
+
continue
|
| 247 |
+
|
| 248 |
+
return filtered[:max_results]
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def is_fresh_enough(
|
| 252 |
+
results: list[dict],
|
| 253 |
+
threshold: int = 3,
|
| 254 |
+
min_year: Optional[int] = None
|
| 255 |
+
) -> bool:
|
| 256 |
+
"""
|
| 257 |
+
Evaluate if local search results meet freshness requirements.
|
| 258 |
+
|
| 259 |
+
Args:
|
| 260 |
+
results: List of book dicts from local search
|
| 261 |
+
threshold: Minimum number of results required
|
| 262 |
+
min_year: If specified, count only books published >= this year
|
| 263 |
+
|
| 264 |
+
Returns:
|
| 265 |
+
True if results are sufficient, False if web fallback should be triggered
|
| 266 |
+
"""
|
| 267 |
+
if len(results) < threshold:
|
| 268 |
+
return False
|
| 269 |
+
|
| 270 |
+
if min_year is None:
|
| 271 |
+
return True
|
| 272 |
+
|
| 273 |
+
# Count books meeting year requirement
|
| 274 |
+
fresh_count = 0
|
| 275 |
+
for book in results:
|
| 276 |
+
pub_date = book.get("publishedDate", "") or book.get("published_date", "")
|
| 277 |
+
if pub_date:
|
| 278 |
+
try:
|
| 279 |
+
pub_year = int(str(pub_date)[:4])
|
| 280 |
+
if pub_year >= min_year:
|
| 281 |
+
fresh_count += 1
|
| 282 |
+
except (ValueError, IndexError):
|
| 283 |
+
continue
|
| 284 |
+
|
| 285 |
+
return fresh_count >= threshold
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def extract_year_from_query(query: str) -> Optional[int]:
|
| 289 |
+
"""
|
| 290 |
+
Extract year requirement from user query.
|
| 291 |
+
|
| 292 |
+
Examples:
|
| 293 |
+
"books from 2024" -> 2024
|
| 294 |
+
"latest 2025 novels" -> 2025
|
| 295 |
+
"new sci-fi" -> current_year
|
| 296 |
+
"classic mystery" -> None
|
| 297 |
+
|
| 298 |
+
Returns:
|
| 299 |
+
Year as int, or None if no year requirement detected
|
| 300 |
+
"""
|
| 301 |
+
import re
|
| 302 |
+
|
| 303 |
+
# Explicit year patterns
|
| 304 |
+
year_patterns = [
|
| 305 |
+
r"\b(202[0-9])\b", # 2020-2029
|
| 306 |
+
r"\b(201[0-9])\b", # 2010-2019
|
| 307 |
+
]
|
| 308 |
+
|
| 309 |
+
for pattern in year_patterns:
|
| 310 |
+
match = re.search(pattern, query)
|
| 311 |
+
if match:
|
| 312 |
+
return int(match.group(1))
|
| 313 |
+
|
| 314 |
+
# Keywords implying "recent" = current year - 1
|
| 315 |
+
freshness_keywords = {
|
| 316 |
+
"new", "newest", "latest", "recent", "modern", "contemporary", "current"
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
words = set(query.lower().split())
|
| 320 |
+
if words & freshness_keywords:
|
| 321 |
+
return datetime.now().year - 1
|
| 322 |
+
|
| 323 |
+
return None
|
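Taken together, these helpers give the router a freshness-aware escape hatch. A minimal sketch of how they compose (the query is hypothetical, the call requires network access to the Google Books API, and live results will vary):

```python
from src.core.web_search import (
    extract_year_from_query,
    is_fresh_enough,
    search_google_books,
)

query = "latest sci-fi novels"             # hypothetical user query
min_year = extract_year_from_query(query)  # "latest" -> current year - 1

local_results = []  # pretend local recall came up empty
if not is_fresh_enough(local_results, threshold=3, min_year=min_year):
    # Local index too thin or too stale: fall back to Google Books
    for book in search_google_books(query, max_results=5):
        print(book["isbn13"], book["title"], book["publishedDate"])
```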
src/ranking/din.py
ADDED
@@ -0,0 +1,212 @@
"""
DIN (Deep Interest Network) for CTR/Ranking.

Uses attention over user behavior sequence w.r.t. target item to capture
user interest. Reuses SASRec item embeddings as initialization when available.

Reference: Zhou et al., "Deep Interest Network for Click-Through Rate Prediction" (KDD 2018)
"""

import logging
from pathlib import Path
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger(__name__)


class DIN(nn.Module):
    """
    Deep Interest Network: attention over user history w.r.t. target item.

    Input:
        - user_hist: [B, max_len] int64, padded behavior sequence (item_ids, 0=pad)
        - target_item: [B] int64, candidate item ids
        - aux_features: [B, num_aux] float32, optional scalar features

    Output: [B] logits for click/positive probability
    """

    def __init__(
        self,
        num_items: int,
        embed_dim: int = 64,
        max_hist_len: int = 50,
        mlp_dims: tuple = (128, 64, 32),
        dropout: float = 0.1,
        num_aux: int = 0,
        pretrained_item_emb: Optional[np.ndarray] = None,
    ):
        super().__init__()
        self.num_items = num_items
        self.embed_dim = embed_dim
        self.max_hist_len = max_hist_len
        self.num_aux = num_aux

        # Item embedding (1-indexed, 0=pad)
        self.item_emb = nn.Embedding(num_items + 1, embed_dim, padding_idx=0)

        if pretrained_item_emb is not None:
            self._init_from_pretrained(pretrained_item_emb)

        # Attention: local activation unit (DIN paper)
        self.attn_fc = nn.Sequential(
            nn.Linear(embed_dim * 4, 36),
            nn.ReLU(),
            nn.Linear(36, 1),
        )

        # MLP: [user_interest; target_emb; aux?] -> score
        mlp_in = embed_dim * 2 + num_aux
        layers = []
        for d in mlp_dims:
            layers.append(nn.Linear(mlp_in, d))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
            mlp_in = d
        layers.append(nn.Linear(mlp_in, 1))
        self.mlp = nn.Sequential(*layers)

    def _init_from_pretrained(self, emb: np.ndarray) -> None:
        """Initialize item_emb from SASRec checkpoint."""
        if emb.shape[0] >= self.num_items + 1 and emb.shape[1] == self.embed_dim:
            with torch.no_grad():
                # Copy only the rows that fit the embedding table; copying all
                # emb.shape[0] rows would overflow when the checkpoint is larger
                self.item_emb.weight.data.copy_(
                    torch.from_numpy(emb[: self.num_items + 1])
                )
            logger.info("DIN: Initialized item_emb from pretrained (%d x %d)", *emb.shape)
        else:
            logger.warning("DIN: Pretrained shape %s mismatch, skipping init", emb.shape)

    def forward(
        self,
        user_hist: torch.Tensor,
        target_item: torch.Tensor,
        aux_features: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        user_hist: [B, L]
        target_item: [B]
        aux_features: [B, num_aux] or None
        """
        # [B, L, E]
        hist_embs = self.item_emb(user_hist)
        # [B, E]
        target_emb = self.item_emb(target_item)

        # Attention: local activation
        # [B, L, E] -> expand target to [B, L, E]
        target_expand = target_emb.unsqueeze(1).expand(-1, user_hist.size(1), -1)
        attn_input = torch.cat([
            hist_embs,
            target_expand,
            hist_embs * target_expand,
            hist_embs - target_expand,
        ], dim=-1)  # [B, L, 4E]
        attn_scores = self.attn_fc(attn_input).squeeze(-1)  # [B, L]

        # Mask padding (0 = pad)
        mask = (user_hist != 0).float()
        attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_weights = F.softmax(attn_scores, dim=1)  # [B, L]
        # When all zeros (no history), attn_weights can be nan; use mask to zero out
        attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
        attn_weights = attn_weights * mask
        attn_weights = attn_weights / (attn_weights.sum(dim=1, keepdim=True) + 1e-9)

        # Weighted sum: user interest vector [B, E]
        user_interest = (hist_embs * attn_weights.unsqueeze(-1)).sum(dim=1)

        # MLP input
        mlp_in = torch.cat([user_interest, target_emb], dim=1)
        if aux_features is not None and self.num_aux > 0 and aux_features.size(-1) == self.num_aux:
            mlp_in = torch.cat([mlp_in, aux_features], dim=1)

        logits = self.mlp(mlp_in).squeeze(-1)
        return logits


class DINRanker:
    """
    Wrapper for DIN model: load, predict, compatible with RecommendationService.
    """

    def __init__(
        self,
        data_dir: str = "data/rec",
        model_dir: str = "data/model",
    ):
        self.data_dir = Path(data_dir)
        self.model_dir = Path(model_dir) / "ranking"
        self.model: Optional[DIN] = None
        self.item_map: dict = {}
        self.id_to_item: dict = {}
        self.user_sequences: dict = {}
        self.max_hist_len = 50
        self.aux_feature_names: list = []
        # Prefer CUDA, then Apple MPS, then CPU
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        else:
            self.device = torch.device("cpu")

    def load(self) -> bool:
        """Load trained DIN and aux data."""
        import pickle

        model_path = self.model_dir / "din_ranker.pt"
        if not model_path.exists():
            return False

        try:
            ckpt = torch.load(model_path, map_location=self.device, weights_only=False)
            self.model = ckpt["model"]
            self.model.to(self.device)
            self.model.eval()
            self.item_map = ckpt.get("item_map", {})
            self.id_to_item = {v: k for k, v in self.item_map.items()}
            self.max_hist_len = ckpt.get("max_hist_len", 50)
            self.aux_feature_names = ckpt.get("aux_feature_names", [])

            with open(self.data_dir / "user_sequences.pkl", "rb") as f:
                seqs = pickle.load(f)
            # user_sequences: user_id -> list of item_ids (int)
            self.user_sequences = seqs

            logger.info("DIN ranker loaded from %s", model_path)
            return True
        except Exception as e:
            logger.error("Failed to load DIN ranker: %s", e)
            return False

    def predict(
        self,
        user_id: str,
        candidate_items: list[str],
        aux_features: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """Predict scores for (user_id, candidate_items). Returns [len(candidate_items)]."""
        if self.model is None:
            self.load()
        if self.model is None:
            return np.zeros(len(candidate_items))

        hist = self.user_sequences.get(user_id, [])
        if hist and isinstance(hist[0], str):
            hist = [self.item_map.get(h, 0) for h in hist]
        hist = hist[-self.max_hist_len:]
        padded = np.zeros(self.max_hist_len, dtype=np.int64)
        padded[: len(hist)] = hist

        target_ids = np.array([self.item_map.get(str(it), 0) for it in candidate_items], dtype=np.int64)

        hist_t = torch.LongTensor(padded).unsqueeze(0).expand(len(candidate_items), -1).to(self.device)
        target_t = torch.LongTensor(target_ids).to(self.device)

        aux_t = None
        if aux_features is not None and aux_features.size > 0:
            aux_t = torch.from_numpy(aux_features.astype(np.float32)).to(self.device)

        with torch.no_grad():
            logits = self.model(hist_t, target_t, aux_t)
        return logits.cpu().numpy()
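A quick shape check for the model above, with random inputs on a toy catalog (all sizes below are invented for illustration; `eval()` disables dropout so the check is deterministic in structure):

```python
import torch

from src.ranking.din import DIN

model = DIN(num_items=100, embed_dim=64, max_hist_len=50, num_aux=0)
model.eval()

user_hist = torch.randint(0, 101, (8, 50))   # [B=8, L=50]; id 0 acts as padding
target_item = torch.randint(1, 101, (8,))    # [B=8] candidate item ids

with torch.no_grad():
    logits = model(user_hist, target_item)

print(logits.shape)  # torch.Size([8]) - one logit per (user, candidate) pair
```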
src/recommender.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import List, Dict, Any
+from typing import List, Dict, Any, Optional
 from src.vector_db import VectorDB
 from src.config import TOP_K_INITIAL, TOP_K_FINAL, DATA_DIR
 from src.cache import CacheManager
@@ -136,13 +136,108 @@ class BookRecommender:
                "emotions": emotions,
                "review_highlights": highlights,
                "persona_summary": "",
-               "average_rating": float(meta.get("average_rating", 0.0))
+               "average_rating": float(meta.get("average_rating", 0.0)),
+               "source": "local",  # Track data source
            })

            if len(results) >= TOP_K_FINAL:
                break
+
+       # 3. Web Search Fallback (Freshness-Aware)
+       # Triggered when: freshness_fallback=True AND local results < threshold
+       if decision.get("freshness_fallback", False):
+           threshold = decision.get("freshness_threshold", 3)
+           if len(results) < threshold:
+               web_results = self._fetch_from_web(query, TOP_K_FINAL - len(results), category)
+               results.extend(web_results)
+               logger.info(f"Web fallback added {len(web_results)} books")
+
+       # Cache the results
+       if results:
+           self.cache.set(cache_key, results)

        return results
+
+   def _fetch_from_web(
+       self,
+       query: str,
+       max_results: int,
+       category: str = "All"
+   ) -> List[Dict[str, Any]]:
+       """
+       Fetch books from Google Books API when local results are insufficient.
+       Auto-persists discovered books to the local database for future queries.
+
+       Args:
+           query: User's search query
+           max_results: Maximum number of results to fetch
+           category: Category filter (not passed to the web search; applied to its results)
+
+       Returns:
+           List of formatted book dicts ready for response
+       """
+       try:
+           from src.core.web_search import search_google_books
+       except ImportError:
+           logger.warning("Web search module not available")
+           return []
+
+       results = []
+
+       try:
+           web_books = search_google_books(query, max_results=max_results * 2)
+
+           for book in web_books:
+               isbn = book.get("isbn13", "")
+               if not isbn:
+                   continue
+
+               # Skip if already in local database
+               if metadata_store.book_exists(isbn):
+                   continue
+
+               # Category filter (if specified)
+               if category and category != "All":
+                   book_cat = book.get("simple_categories", "")
+                   if category.lower() not in book_cat.lower():
+                       continue
+
+               # Auto-persist to local database
+               added = self.add_new_book(
+                   isbn=isbn,
+                   title=book.get("title", ""),
+                   author=book.get("authors", "Unknown"),
+                   description=book.get("description", ""),
+                   category=book.get("simple_categories", "General"),
+                   thumbnail=book.get("thumbnail"),
+                   published_date=book.get("publishedDate", ""),
+               )
+
+               if added:
+                   results.append({
+                       "isbn": isbn,
+                       "title": book.get("title", ""),
+                       "authors": book.get("authors", "Unknown"),
+                       "description": book.get("description", ""),
+                       "thumbnail": book.get("thumbnail", ""),
+                       "caption": f"{book.get('title', '')} by {book.get('authors', 'Unknown')}",
+                       "tags": [],
+                       "emotions": {"joy": 0.0, "sadness": 0.0, "fear": 0.0, "anger": 0.0, "surprise": 0.0},
+                       "review_highlights": [],
+                       "persona_summary": "",
+                       "average_rating": float(book.get("average_rating", 0.0)),
+                       "source": "google_books",  # Track data source
+                   })
+
+               if len(results) >= max_results:
+                   break
+
+           logger.info(f"Web fallback: Found and persisted {len(results)} new books")
+           return results
+
+       except Exception as e:
+           logger.error(f"Web fallback failed: {e}")
+           return []

    def get_categories(self) -> List[str]:
        """Get unique book categories from SQLite."""
@@ -152,20 +247,47 @@ class BookRecommender:
        """Get available emotional tones."""
        return ["All", "Happy", "Sad", "Fear", "Anger", "Surprise"]

-   def add_new_book(
+   def add_new_book(
+       self,
+       isbn: str,
+       title: str,
+       author: str,
+       description: str,
+       category: str = "General",
+       thumbnail: Optional[str] = None,
+       published_date: Optional[str] = None,
+   ) -> Optional[Dict[str, Any]]:
        """
+       Add a new book to the system: CSV, SQLite (with FTS5), and ChromaDB.
+
+       Args:
+           isbn: ISBN-13 or ISBN-10
+           title: Book title
+           author: Author name(s)
+           description: Book description
+           category: Book category
+           thumbnail: Cover image URL
+           published_date: Publication date (YYYY, YYYY-MM, or YYYY-MM-DD)
+
+       Returns:
+           New book dictionary if successful, None otherwise
        """
        try:
            import pandas as pd

+           isbn_s = str(isbn).strip()
+
+           # Check if already exists
+           if metadata_store.book_exists(isbn_s):
+               logger.debug(f"Book {isbn} already exists. Skipping add.")
+               return None
+
            # 1. Update Persistent Storage (CSV)
            csv_path = DATA_DIR / "books_processed.csv"

            # Define new row with all expected columns
            new_row = {
+               "isbn13": isbn_s,
                "title": title,
                "authors": author,
                "description": description,
@@ -174,13 +296,10 @@ class BookRecommender:
                "average_rating": 0.0,
                "joy": 0.0, "sadness": 0.0, "fear": 0.0, "anger": 0.0, "surprise": 0.0,
                "tags": "", "review_highlights": "",
+               "isbn10": isbn_s[:10] if len(isbn_s) >= 10 else isbn_s,
+               "publishedDate": published_date or "",
+               "source": "google_books",  # Track data source
            }
-
-           isbn_s = str(isbn)
-           if metadata_store.get_book_metadata(isbn_s):
-               logger.warning(f"Book {isbn} already exists. Skipping add.")
-               return None

            # Append to CSV
            if csv_path.exists():
@@ -191,19 +310,20 @@ class BookRecommender:
                # Filter/Order new_row to match CSV structure
                ordered_row = {}
                for col in csv_columns:
                    ordered_row[col] = new_row.get(col, "")

                # Append to CSV
                pd.DataFrame([ordered_row]).to_csv(csv_path, mode='a', header=False, index=False)
            else:
+               pd.DataFrame([new_row]).to_csv(csv_path, index=False)

            new_row["large_thumbnail"] = new_row["thumbnail"]
+           new_row["image"] = new_row["thumbnail"]

+           # 2. Insert into SQLite with FTS5 (incremental indexing)
+           metadata_store.insert_book_with_fts(new_row)

+           # 3. Update Vector DB (ChromaDB)
            self.vector_db.add_book(new_row)

            logger.info(f"Successfully added book {isbn}: {title}")
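A sketch of the persistence path that `_fetch_from_web()` relies on. The constructor call and the ISBN below are assumptions for illustration (construction is not shown in this diff); the important parts are the dual-write contract (CSV, then SQLite/FTS5, then ChromaDB) and the idempotency guard:

```python
from src.recommender import BookRecommender

rec = BookRecommender()  # assumed default construction; not shown in this diff

book = rec.add_new_book(
    isbn="9781234567897",  # invented ISBN for illustration
    title="Example: A Novel",
    author="Jane Doe",
    description="Placeholder description.",
    category="Fiction",
    published_date="2024-05",
)
print(book is not None)  # True on the first insert, per the docstring contract

# A second call short-circuits on metadata_store.book_exists() and returns None,
# so repeated web-fallback passes can safely re-discover the same book.
print(rec.add_new_book(isbn="9781234567897", title="Example: A Novel",
                       author="Jane Doe", description="Placeholder description."))
```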
src/services/recommend_service.py
CHANGED
@@ -7,10 +7,12 @@ from pathlib import Path
 from src.recall.fusion import RecallFusion
 from src.ranking.features import FeatureEngineer
 from src.ranking.explainer import RankingExplainer
+from src.ranking.din import DINRanker
 from src.utils import setup_logger

 logger = setup_logger(__name__)

+
 class RecommendationService:
     def __init__(self, data_dir='data/rec', model_dir='data/model'):
         self.data_dir = Path(data_dir)
@@ -21,6 +23,8 @@

        self.ranker = None
        self.ranker_loaded = False
+       self.din_ranker = DINRanker(str(data_dir), str(model_dir))
+       self.din_ranker_loaded = False
        self.xgb_ranker = None
        self.meta_model = None
        self.use_stacking = False
@@ -34,7 +38,14 @@
        self.fusion.load_models()
        self.fe.load_base_data()

+       # Prefer DIN ranker when available (deep model)
+       din_path = self.model_dir / 'ranking/din_ranker.pt'
+       if din_path.exists():
+           if self.din_ranker.load():
+               self.din_ranker_loaded = True
+               logger.info("DIN ranker loaded — using deep model for ranking")
+
+       # Load LGBM ranker (fallback when DIN not available)
        ranker_path = self.model_dir / 'ranking/lgbm_ranker.txt'
        if ranker_path.exists():
            self.ranker = lgb.Booster(model_file=str(ranker_path))
@@ -119,29 +130,34 @@
        candidate_items = [item for item, score in candidates]

        # 2. Ranking
+       valid_candidates = [item for item in candidate_items if item not in fav_isbns]
+       if not valid_candidates:
+           return []

+       if self.din_ranker_loaded:
+           # DIN: deep model; optional aux features from FeatureEngineer
+           aux_arr = None
+           if self.din_ranker.aux_feature_names:
+               X_df = self.fe.generate_features_batch(user_id, valid_candidates)
+               for col in self.din_ranker.aux_feature_names:
+                   if col not in X_df.columns:
+                       X_df[col] = 0
+               aux_arr = X_df[self.din_ranker.aux_feature_names].values.astype(np.float32)
+           scores = self.din_ranker.predict(user_id, valid_candidates, aux_arr)
+           explanations_list = [[] for _ in valid_candidates]
+           final_scores = list(zip(valid_candidates, scores, explanations_list))
+           final_scores.sort(key=lambda x: x[1], reverse=True)
+       elif self.ranker_loaded:
+           # LGBM / stacking path
            X_df = self.fe.generate_features_batch(user_id, valid_candidates)
-
-           # Align features to match model
            model_features = self.ranker.feature_name()
            for col in model_features:
                if col not in X_df.columns:
                    X_df[col] = 0
            X_df = X_df[model_features]

-           # Predict
            if self.use_stacking and self.xgb_ranker is not None and self.meta_model is not None:
-               # Stacking: Level-1 predictions -> Level-2 meta-learner
                lgb_scores = self.ranker.predict(X_df)
-
-               # Check if XGB Ranker is a raw Booster or Sklearn Estimator
                if isinstance(self.xgb_ranker, xgb.Booster):
                    dtest = xgb.DMatrix(X_df)
                    xgb_scores = self.xgb_ranker.predict(dtest)
@@ -150,24 +166,19 @@
                meta_features = np.column_stack([lgb_scores, xgb_scores])
                scores = self.meta_model.predict_proba(meta_features)[:, 1]
            else:
-               # Fallback: LightGBM only (backward compatible)
                scores = self.ranker.predict(X_df)

-           # Compute SHAP explanations (V2.7)
            explanations_list = []
            if self.explainer is not None:
                try:
                    explanations_list = self.explainer.explain(X_df, top_k=3)
                except Exception as e:
-                   logger.warning(f"SHAP explanation failed: {e}")
                    explanations_list = [[] for _ in valid_candidates]
            else:
                explanations_list = [[] for _ in valid_candidates]

-           # Combine with explanations
            final_scores = list(zip(valid_candidates, scores, explanations_list))
            final_scores.sort(key=lambda x: x[1], reverse=True)
-
        else:
            # Fallback to recall scores, but filter
            final_scores = []
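For the stacking branch kept above, a minimal sketch of the level-2 step, assuming the meta-model is a scikit-learn classifier exposing `predict_proba` (the scores and labels below are toy values; in the service the meta-learner is fitted offline):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Toy level-1 scores for 4 candidates
lgb_scores = np.array([0.9, 0.2, 0.6, 0.4])
xgb_scores = np.array([0.8, 0.3, 0.5, 0.7])

# Same shape as in the service: one row per candidate, one column per base model
meta_features = np.column_stack([lgb_scores, xgb_scores])  # (4, 2)

# Fitted here on dummy labels just so the sketch runs end to end
meta_model = LogisticRegression().fit(meta_features, [1, 0, 1, 0])
scores = meta_model.predict_proba(meta_features)[:, 1]     # final ranking scores, (4,)
print(np.argsort(-scores))                                 # candidate order, best first
```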
src/vector_db.py
CHANGED
@@ -321,7 +321,14 @@

    def add_book(self, book_data: dict):
        """
+       Dynamically add a new book to the ChromaDB vector index.
+
+       Note: FTS5 incremental updates are handled separately via
+       metadata_store.insert_book_with_fts() called from BookRecommender.add_new_book().
+       This method only handles the dense vector index.
+
+       Args:
+           book_data: Dict with isbn13, title, authors, description, etc.
        """
        from langchain_core.documents import Document

@@ -330,7 +337,7 @@
        author = book_data.get("authors", "")
        description = book_data.get("description", "")

+       # Add to ChromaDB (dense vector index)
        content = f"Title: {title}\nAuthor: {author}\nDescription: {description}\nISBN: {isbn}"
        doc = Document(
            page_content=content,
@@ -346,7 +353,4 @@
        if self.db:
            self.db.add_documents([doc])
            logger.info(f"Added book {isbn} to ChromaDB")
-
-       if hasattr(self, 'fts_enabled') and self.fts_enabled:
-           logger.info("Note: FTS5 database updates are not implemented in add_book yet.")
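For reference, the document that `add_book()` pushes into ChromaDB looks roughly like this. The payload values and the metadata key are invented for illustration (the diff elides the actual `Document` metadata):

```python
from langchain_core.documents import Document

book_data = {
    "isbn13": "9781234567897",  # invented example payload
    "title": "Example: A Novel",
    "authors": "Jane Doe",
    "description": "Placeholder description.",
}

# Same page_content format as VectorDB.add_book()
content = (
    f"Title: {book_data['title']}\n"
    f"Author: {book_data['authors']}\n"
    f"Description: {book_data['description']}\n"
    f"ISBN: {book_data['isbn13']}"
)
doc = Document(page_content=content, metadata={"isbn13": book_data["isbn13"]})
print(doc.page_content)
```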