Spaces:
Running
Running
| # main.py | |
| from fastapi import FastAPI, HTTPException, status, File, UploadFile, Form, Query | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from typing import Optional | |
| import pandas as pd | |
| import io | |
| import os | |
| from text_engine import Text_Search_Engine | |
import logging

logger = logging.getLogger(__name__)

app = FastAPI(
    title="CortexSearch",
    version="1.0",
    description="A flexible text search API with multiple FAISS index types and BM25 support.",
)

# The default index type is read from the INDEX_TYPE environment variable
# ("flat", "ivf", or "hnsw"); "flat" is used when the variable is unset.
store = Text_Search_Engine(index_type=os.getenv("INDEX_TYPE", "flat"))

# Loading is best-effort: a failure here must not stop the app from starting,
# but the reason is logged rather than silently discarded so that a broken
# load is distinguishable from an intentionally empty engine.
try:
    store.load()
except Exception:
    logger.warning(
        "store.load() failed; continuing with the engine's initial state.",
        exc_info=True,
    )

# NOTE(review): wildcard origins are combined with allow_credentials=True.
# The FastAPI CORS documentation advises listing origins explicitly when
# credentials are enabled - confirm this permissive setup is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm that it is registered on `app`.
async def root():
    """Return a fixed banner showing that the API process is responding."""
    banner = {"Status": "The CortexSearch API is live!!!"}
    return banner
| # ------------------------- | |
| # Column preview endpoint | |
| # ------------------------- | |
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm that it is registered on `app`.
async def list_columns(file: UploadFile = File(...)):
    """
    Report the column names found in an uploaded CSV.

    Lets a client inspect the header before choosing which columns to index.
    Any failure while reading or parsing the upload is reported as HTTP 400
    with the exception text as the detail.
    """
    try:
        raw_bytes = await file.read()
        buffer = io.BytesIO(raw_bytes)
        frame = pd.read_csv(buffer)
        column_names = list(frame.columns)
    except Exception as exc:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc))
    return {"available_columns": column_names}
| # ------------------------- | |
| # Health check endpoint | |
| # ------------------------- | |
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm that it is registered on `app`.
async def health():
    """Return a liveness payload with the stored row count and index type."""
    indexed_rows = len(store.rows)
    active_index = store.index_type
    return {"status": "ok", "rows_indexed": indexed_rows, "index_type": active_index}
| # ------------------------- | |
| # Upload CSV (build fresh index) | |
| # ------------------------- | |
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm that it is registered on `app`.
async def upload_csv(file: UploadFile = File(...), columns: str = Form(...), index_type: Optional[str] = Form(None)):
    """Build a fresh index from an uploaded CSV.

    Args:
        file: The uploaded CSV file.
        columns: Comma-separated column names whose values are joined with
            spaces to form each row's searchable text.
        index_type: Optional form field ("flat", "ivf", or "hnsw") that
            overrides the engine's current index type.

    Returns:
        On success, a dict with ``status``, ``count``, ``used_columns`` and
        ``index_type``. If no column name is supplied or a named column is
        not in the CSV, a dict with ``status`` set to ``"error"``, a
        ``detail`` message and ``available_columns``.

    Raises:
        HTTPException: 400 when the upload cannot be read or parsed as CSV,
            500 when preparing or indexing the rows fails.
    """
    # An unreadable upload is a client error; report it the same way the
    # column-preview handler does instead of as a server failure.
    try:
        contents = await file.read()
        df = pd.read_csv(io.BytesIO(contents))
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e

    try:
        column_list = [c.strip() for c in columns.split(",") if c.strip()]
        available_columns = list(df.columns)

        # With no columns, dropna(subset=[]) would keep every row and the
        # index would be rebuilt from empty strings, so reject the request.
        if not column_list:
            return {
                "status": "error",
                "detail": "No columns specified.",
                "available_columns": available_columns,
            }

        # Validate
        for col in column_list:
            if col not in df.columns:
                return {
                    "status": "error",
                    "detail": f"Column '{col}' not found.",
                    "available_columns": available_columns,
                }

        # Rows lacking a value in any indexed column are dropped. Missing
        # values in the remaining columns become None, because NaN cannot be
        # serialized when stored rows are later returned in a JSON response.
        rows = [
            {key: (None if pd.isna(value) else value) for key, value in record.items()}
            for record in df.dropna(subset=column_list).to_dict(orient="records")
        ]
        for r in rows:
            r["_search_text"] = " ".join(str(r[col]) for col in column_list if r.get(col) is not None)
        texts = [r["_search_text"] for r in rows]

        if index_type:
            store.index_type = index_type
        store.encode_store(rows, texts)
        return {"status": "success", "count": len(rows), "used_columns": column_list, "index_type": store.index_type}
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) from e
| # ------------------------- | |
| # Add CSV (append new rows) | |
| # ------------------------- | |
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm that it is registered on `app`.
async def add_csv(file: UploadFile = File(...), columns: str = Form(...)):
    """Append the rows of an uploaded CSV to the existing store.

    Args:
        file: The uploaded CSV file.
        columns: Comma-separated column names whose values are joined with
            spaces to form each new row's searchable text.

    Returns:
        On success, a dict with ``status``, ``added_count``, ``used_columns``
        and ``total_rows``. If no column name is supplied or a named column
        is not in the CSV, a dict with ``status`` set to ``"error"``, a
        ``detail`` message and ``available_columns``.

    Raises:
        HTTPException: 400 when the upload cannot be read or parsed as CSV,
            500 when preparing or adding the rows fails.
    """
    # An unreadable upload is a client error; report it the same way the
    # column-preview handler does instead of as a server failure.
    try:
        contents = await file.read()
        df = pd.read_csv(io.BytesIO(contents))
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e

    try:
        column_list = [c.strip() for c in columns.split(",") if c.strip()]
        available_columns = list(df.columns)

        # With no columns, dropna(subset=[]) would keep every row and rows
        # with empty search text would be appended, so reject the request.
        if not column_list:
            return {
                "status": "error",
                "detail": "No columns specified.",
                "available_columns": available_columns,
            }

        for col in column_list:
            if col not in df.columns:
                return {
                    "status": "error",
                    "detail": f"Column '{col}' not found.",
                    "available_columns": available_columns,
                }

        # Rows lacking a value in any indexed column are dropped. Missing
        # values in the remaining columns become None, because NaN cannot be
        # serialized when stored rows are later returned in a JSON response.
        new_rows = [
            {key: (None if pd.isna(value) else value) for key, value in record.items()}
            for record in df.dropna(subset=column_list).to_dict(orient="records")
        ]
        for r in new_rows:
            r["_search_text"] = " ".join(str(r[col]) for col in column_list if r.get(col) is not None)
        new_texts = [r["_search_text"] for r in new_rows]

        store.add_rows(new_rows, new_texts)
        return {"status": "success", "added_count": len(new_rows), "used_columns": column_list, "total_rows": len(store.rows)}
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) from e
| # ------------------------- | |
| # Search endpoint | |
| # ------------------------- | |
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm that it is registered on `app`.
async def search(
    query: str,
    top_k: int = 3,
    mode: str = Query("semantic", enum=["semantic", "lexical", "hybrid"]),
    alpha: float = 0.5,
):
    """Search the indexed rows.

    Args:
        query: The search text.
        top_k: Maximum number of results to return; must be at least 1.
        mode: "semantic" delegates to ``store.search``, "lexical" ranks rows
            by the store's BM25 scores, and any other value delegates to
            ``store.hybrid_search``.
        alpha: Weight for the semantic part in hybrid mode; must lie in
            the range 0..1. Only checked when the hybrid branch is taken.

    Returns:
        A dict with a single ``results`` list. In lexical mode the list is
        empty when the store has no BM25 index; otherwise each entry is the
        stored row plus a float ``score``.

    Raises:
        HTTPException: 400 for an out-of-range ``top_k`` or ``alpha``,
            500 when the underlying search fails.
    """
    # Validate before entering the try block: the broad handler below would
    # otherwise re-wrap these 400 responses as 500 errors.
    if top_k < 1:
        # A negative top_k would make the lexical slice below drop rows from
        # the end instead of limiting the result count.
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="top_k must be at least 1.")
    uses_alpha = mode not in ("semantic", "lexical")
    if uses_alpha and not 0.0 <= alpha <= 1.0:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="alpha must be between 0 and 1.")

    try:
        if mode == "semantic":
            results = store.search(query, top_k=top_k)
        elif mode == "lexical":
            if store.bm25 is None:
                return {"results": []}
            # Lowercased whitespace tokens are scored against every row;
            # the highest-scoring top_k row indices are kept.
            tokenized_query = query.lower().split()
            scores = store.bm25.get_scores(tokenized_query)
            ranked = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)[:top_k]
            results = [{**store.rows[i], "score": float(score)} for i, score in ranked]
        else:
            results = store.hybrid_search(query, top_k=top_k, alpha=alpha)
        return {"results": results}
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) from e
| # ------------------------- | |
| # Delete all data | |
| # ------------------------- | |
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm that it is registered on `app`.
async def delete_data():
    """Clear the store via ``store.clear_vdb()`` and confirm success.

    Any exception raised while clearing is reported as HTTP 500 with the
    exception text as the detail.
    """
    try:
        store.clear_vdb()
    except Exception as exc:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc))
    return {"status": "success", "message": "Vector DB cleared"}