# text_engine.py
"""Hybrid dense (FAISS) + lexical (BM25) text search with on-disk persistence."""
import os
import pickle
import logging
from typing import List, Optional

import numpy as np
from sentence_transformers import SentenceTransformer
import faiss
from rank_bm25 import BM25Okapi

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _tokenize(text: str) -> List[str]:
    """Lower-case and whitespace-split *text* (the BM25 tokenizer for corpus and queries)."""
    return text.lower().split()


def _min_max_normalize(scores: dict) -> dict:
    """Linearly rescale the values of *scores* to [0, 1].

    An empty mapping is returned empty. When every value is equal (including a
    single entry) every key maps to 0.0, because there is no spread to rescale.
    """
    if not scores:
        return {}
    lo = min(scores.values())
    hi = max(scores.values())
    span = hi - lo
    if span == 0:
        return {key: 0.0 for key in scores}
    return {key: (value - lo) / span for key, value in scores.items()}


class Text_Search_Engine:
    """Search a collection of row dicts by embedding similarity and/or BM25.

    ``self.rows[i]``, ``self.texts[i]`` and vector id ``i`` in ``self.index``
    all describe the same document; every mutating method keeps them aligned.
    The FAISS index and the rows are persisted under ``base_folder``.
    """

    def __init__(
        self,
        base_folder: str = "vector_store",
        model_name: str = "sentence-transformers/LaBSE",
        index_type: str = "flat",
    ):
        """Create the storage folders and load the sentence-embedding model.

        Args:
            base_folder: Root folder; the index is stored under
                ``<base_folder>/embeddings`` and the pickled rows under
                ``<base_folder>/documents``.
            model_name: Model name passed to ``SentenceTransformer``.
            index_type: ``"flat"``, ``"ivf"`` or ``"hnsw"`` (see ``_create_index``).
        """
        self.base_folder = base_folder
        self.embeddings_folder = os.path.join(base_folder, "embeddings")
        self.docs_folder = os.path.join(base_folder, "documents")
        os.makedirs(self.embeddings_folder, exist_ok=True)
        os.makedirs(self.docs_folder, exist_ok=True)

        self.model = SentenceTransformer(model_name)
        self.index: Optional[faiss.Index] = None
        self.rows: List[dict] = []
        self.texts: List[str] = []
        self.bm25: Optional[BM25Okapi] = None
        self.index_type = index_type

    # -------------------------
    # Index creation utilities
    # -------------------------
    def _index_path(self) -> str:
        """Return the path of the persisted FAISS index file."""
        return os.path.join(self.embeddings_folder, "multilingual.index")

    def _rows_path(self) -> str:
        """Return the path of the pickled rows file."""
        return os.path.join(self.docs_folder, "rows.pkl")

    def _create_index(self, dimension: int, embeddings: np.ndarray) -> None:
        """Create an empty FAISS index of ``self.index_type`` into ``self.index``.

        For ``"ivf"`` the index is also trained on *embeddings*; the number of
        lists is about one per 10 vectors, clamped to [1, 256].

        Raises:
            ValueError: If ``self.index_type`` is not a supported type.
        """
        if self.index_type == "flat":
            self.index = faiss.IndexFlatL2(dimension)
        elif self.index_type == "ivf":
            nlist = max(1, min(256, len(embeddings) // 10))
            quantizer = faiss.IndexFlatL2(dimension)
            self.index = faiss.IndexIVFFlat(quantizer, dimension, nlist, faiss.METRIC_L2)
            self.index.train(np.asarray(embeddings, dtype="float32"))
        elif self.index_type == "hnsw":
            self.index = faiss.IndexHNSWFlat(dimension, 32)
        else:
            raise ValueError(f"Unsupported index type: {self.index_type}")

    def _rebuild_bm25(self) -> None:
        """Rebuild the BM25 model over ``self.texts``.

        rank_bm25 divides by the corpus size while building, so an empty corpus
        leaves ``self.bm25`` as ``None`` instead of raising.
        """
        if self.texts:
            self.bm25 = BM25Okapi([_tokenize(text) for text in self.texts])
        else:
            self.bm25 = None

    def _persist(self) -> None:
        """Write the index and rows to disk (best effort: failures are only logged)."""
        try:
            if self.index is not None:
                faiss.write_index(self.index, self._index_path())
            with open(self._rows_path(), "wb") as f:
                pickle.dump(self.rows, f)
            logger.info("Persisted index and rows to disk.")
        except Exception as e:
            logger.exception("Failed to persist index/rows: %s", e)

    @staticmethod
    def _check_aligned(rows: List[dict], texts: List[str]) -> None:
        """Raise ValueError unless *rows* and *texts* have the same length."""
        if len(rows) != len(texts):
            raise ValueError(
                f"rows and texts must have the same length "
                f"(got {len(rows)} rows and {len(texts)} texts)"
            )

    # -------------------------
    # Core operations
    # -------------------------
    def encode_store(self, rows: List[dict], texts: List[str]) -> None:
        """Replace the whole store with *rows*, embedding the parallel *texts*.

        Builds a new index, rebuilds BM25 and persists to disk. Errors are
        logged and re-raised.

        Raises:
            ValueError: If ``rows`` and ``texts`` differ in length.
        """
        try:
            self._check_aligned(rows, texts)
            embeddings = np.asarray(
                self.model.encode(texts, convert_to_numpy=True), dtype="float32"
            )
            self._create_index(embeddings.shape[1], embeddings)
            self.index.add(embeddings)
            # Copy so later add_rows() calls do not mutate the caller's lists.
            self.rows = list(rows)
            self.texts = list(texts)
            self._rebuild_bm25()
            self._persist()
            logger.info("Index built with %d rows (index_type=%s).", len(rows), self.index_type)
        except Exception as e:
            logger.exception("Error in encode_store: %s", e)
            raise

    def load(self) -> None:
        """Load the persisted index and rows, if both files exist.

        Search texts are not persisted separately: they are rebuilt from each
        row's ``"_search_text"`` key, so persisted rows must carry that key
        (a row without it makes this raise KeyError). Errors are logged and
        re-raised.
        """
        try:
            index_path = self._index_path()
            rows_path = self._rows_path()
            if not (os.path.exists(index_path) and os.path.exists(rows_path)):
                logger.info("No persisted index/rows found.")
                return
            self.index = faiss.read_index(index_path)
            # SECURITY: pickle.load can execute arbitrary code; only load
            # rows.pkl from a location that untrusted parties cannot write to.
            with open(rows_path, "rb") as f:
                self.rows = pickle.load(f)
            self.texts = [row["_search_text"] for row in self.rows]
            self._rebuild_bm25()
            logger.info("Loaded index and %d rows from disk.", len(self.rows))
        except Exception as e:
            logger.exception("Error in load: %s", e)
            raise

    def add_rows(self, new_rows: List[dict], new_texts: List[str]) -> None:
        """Append *new_rows* (with parallel *new_texts*) to the store and persist.

        Creates the index on first use. An IVF index that is not yet trained is
        trained on all known texts plus the new ones before adding. Does
        nothing when *new_rows* is empty. Errors are logged and re-raised.

        Raises:
            ValueError: If ``new_rows`` and ``new_texts`` differ in length.
        """
        try:
            if not new_rows:
                return
            self._check_aligned(new_rows, new_texts)
            new_embeddings = np.asarray(
                self.model.encode(new_texts, convert_to_numpy=True), dtype="float32"
            )
            if self.index is None:
                self._create_index(new_embeddings.shape[1], new_embeddings)
            elif isinstance(self.index, faiss.IndexIVFFlat) and not self.index.is_trained:
                training = new_embeddings
                if self.texts:
                    existing = np.asarray(
                        self.model.encode(self.texts, convert_to_numpy=True), dtype="float32"
                    )
                    training = np.vstack([existing, new_embeddings])
                self.index.train(training)
            self.index.add(new_embeddings)
            self.rows.extend(new_rows)
            self.texts.extend(new_texts)
            self._rebuild_bm25()
            self._persist()
            logger.info("Added %d new rows. Total rows: %d", len(new_rows), len(self.rows))
        except Exception as e:
            logger.exception("Error in add_rows: %s", e)
            raise

    # -------------------------
    # Search methods
    # -------------------------
    def search(self, query: str, top_k: int = 3) -> List[dict]:
        """Return up to *top_k* rows nearest to *query* in embedding space.

        Each result is a copy of the row with an added ``"distance"`` (the
        distance FAISS reports; smaller is closer), sorted ascending. Returns
        [] when the store is empty, ``top_k <= 0``, or on any error (logged).
        """
        try:
            if self.index is None or not self.rows or top_k <= 0:
                return []
            query_emb = np.asarray(
                self.model.encode([query], convert_to_numpy=True), dtype="float32"
            )
            k = min(top_k, len(self.rows))
            distances, indices = self.index.search(query_emb, k=k)
            results = [
                {**self.rows[int(idx)], "distance": float(dist)}
                for dist, idx in zip(distances[0], indices[0])
                # FAISS pads with label -1 when fewer than k neighbours are found
                # (e.g. IVF probing a single list); -1 would wrongly hit rows[-1].
                if idx >= 0
            ]
            return sorted(results, key=lambda r: r["distance"])
        except Exception as e:
            logger.exception("Error in search: %s", e)
            return []

    def hybrid_search(
        self,
        query: str,
        top_k: int = 3,
        alpha: float = 0.5,
        candidate_pool: int = 20,
    ) -> List[dict]:
        """Rank rows by a weighted mix of semantic and BM25 relevance.

        Dense retrieval picks ``max(top_k, candidate_pool)`` candidates (capped
        at the corpus size). Each candidate's semantic similarity ``1/(1+d)``
        and BM25 score are min-max normalized over the candidate set and mixed
        as ``alpha * semantic + (1 - alpha) * lexical``.

        Args:
            query: Free-text query.
            top_k: Maximum number of results.
            alpha: Weight of the semantic score (1.0 = dense only, 0.0 = BM25 only).
            candidate_pool: Minimum number of dense candidates to re-rank.

        Returns:
            Up to *top_k* row copies with an added ``"score"``, best first, or
            [] when nothing is indexed, ``top_k <= 0``, or on error (logged).
        """
        try:
            if self.index is None or self.bm25 is None or not self.rows or top_k <= 0:
                return []

            query_emb = np.asarray(
                self.model.encode([query], convert_to_numpy=True), dtype="float32"
            )

            # Candidate pool: never smaller than top_k, or results would be capped.
            retrieve_k = min(max(top_k, candidate_pool), len(self.texts))
            distances, indices = self.index.search(query_emb, k=retrieve_k)
            # Drop FAISS's -1 padding labels (fewer neighbours than requested).
            candidates = [
                (int(idx), float(dist))
                for dist, idx in zip(distances[0], indices[0])
                if idx >= 0
            ]
            if not candidates:
                return []

            # Map distance to a similarity in (0, 1]; smaller distance -> larger score.
            sem_scores = {idx: 1.0 / (1.0 + dist) for idx, dist in candidates}
            bm25_scores = self.bm25.get_scores(_tokenize(query))
            lex_scores = {idx: float(bm25_scores[idx]) for idx, _ in candidates}

            # Put both signals on a common [0, 1] scale before mixing them.
            sem_scores = _min_max_normalize(sem_scores)
            lex_scores = _min_max_normalize(lex_scores)

            combined = [
                {
                    **self.rows[idx],
                    "score": float(alpha * sem_scores[idx] + (1 - alpha) * lex_scores[idx]),
                }
                for idx, _ in candidates
            ]
            combined.sort(key=lambda r: r["score"], reverse=True)
            return combined[:top_k]
        except Exception as e:
            logger.exception("Error in hybrid_search: %s", e)
            return []

    # -------------------------
    # Utilities
    # -------------------------
    def clear_vdb(self) -> None:
        """Empty the in-memory store and delete the persisted files.

        The index object is reset in place; if ``reset()`` fails it is dropped
        (set to ``None``) so the next add recreates it. Errors are logged and
        re-raised.
        """
        try:
            if self.index is not None:
                try:
                    self.index.reset()
                except Exception:
                    self.index = None
            self.rows = []
            self.texts = []
            self.bm25 = None

            for path in (self._index_path(), self._rows_path()):
                try:
                    os.remove(path)
                except FileNotFoundError:
                    pass  # Nothing persisted at this path; already clear.
            logger.info("Cleared vector DB and persisted files.")
        except Exception as e:
            logger.exception("Error in clear_vdb: %s", e)
            raise