Spaces:
Running
Running
File size: 5,660 Bytes
# main.py
import io
import logging
import os
from typing import Optional

import pandas as pd
from fastapi import FastAPI, HTTPException, status, File, UploadFile, Form, Query
from fastapi.middleware.cors import CORSMiddleware

from text_engine import Text_Search_Engine
logger = logging.getLogger(__name__)

app = FastAPI(
    title="CortexSearch",
    version="1.0",
    description="A flexible text search API with multiple FAISS index types and BM25 support.",
)

# The default index type is read from the INDEX_TYPE environment variable and
# falls back to "flat". Choose one of: "flat", "ivf", or "hnsw".
store = Text_Search_Engine(index_type=os.getenv("INDEX_TYPE", "flat"))

# Loading is best-effort: a failure must not prevent the API from starting.
# The failure is logged (with traceback) instead of being silently discarded,
# so an unexpectedly empty engine after a restart can be diagnosed.
# NOTE(review): presumably load() restores previously persisted data - confirm
# against Text_Search_Engine.
try:
    store.load()
except Exception:
    logger.warning("store.load() failed; continuing without loaded data", exc_info=True)

# NOTE(review): wildcard origins combined with allow_credentials=True is
# discouraged by the FastAPI/Starlette CORS documentation - confirm whether
# credentialed cross-origin requests are really needed and, if so, list the
# allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def root():
    """Return a fixed liveness message for the API root."""
    payload = {"Status": "The CortexSearch API is live!!!"}
    return payload
# -------------------------
# Column preview endpoint
# -------------------------
@app.post("/list_columns")
async def list_columns(file: UploadFile = File(...)):
    """
    Parse an uploaded CSV and report its column names.

    Intended as a preview step before choosing which columns to index.

    Returns:
        A dict with a single key, "available_columns", listing the CSV header.

    Raises:
        HTTPException: 400 with the underlying error text when the upload
            cannot be read or parsed.
    """
    try:
        raw_bytes = await file.read()
        buffer = io.BytesIO(raw_bytes)
        frame = pd.read_csv(buffer)
        column_names = list(frame.columns)
    except Exception as exc:
        message = str(exc)
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message)
    return {"available_columns": column_names}
# -------------------------
# Health check endpoint
# -------------------------
@app.get("/health")
async def health():
    """Report service status, the number of rows held by the store, and its index type."""
    indexed_count = len(store.rows)
    current_index_type = store.index_type
    return {
        "status": "ok",
        "rows_indexed": indexed_count,
        "index_type": current_index_type,
    }
# -------------------------
# Upload CSV (build fresh index)
# -------------------------
@app.post("/upload_csv")
async def upload_csv(
    file: UploadFile = File(...),
    columns: str = Form(...),
    index_type: Optional[str] = Form(None),
):
    """
    Upload a CSV and index it from the selected columns.

    Args:
        file: The CSV upload.
        columns: Comma-separated column names; their values are joined with
            spaces into each row's "_search_text".
        index_type: Optional override of the engine's index type
            ('flat', 'ivf', or 'hnsw').

    Returns:
        On success: status, number of indexed rows, the columns used and the
        active index type. If no column is given or a column is missing from
        the CSV: a dict with status "error", a detail message and the
        available columns.

    Raises:
        HTTPException: 500 with the underlying error text for any other failure.
    """
    try:
        contents = await file.read()
        df = pd.read_csv(io.BytesIO(contents))
        column_list = [c.strip() for c in columns.split(",") if c.strip()]

        # An empty selection (e.g. columns=",") would give every row an empty
        # search text, so reject it the same way as an unknown column.
        if not column_list:
            return {
                "status": "error",
                "detail": "No columns specified.",
                "available_columns": list(df.columns),
            }

        # Validate
        for col in column_list:
            if col not in df.columns:
                return {
                    "status": "error",
                    "detail": f"Column '{col}' not found.",
                    "available_columns": list(df.columns),
                }

        # Rows lacking a value in any selected column are skipped.
        # NOTE(review): non-selected columns may still contain NaN, which is
        # not valid JSON - confirm that search responses serialize such rows.
        rows = df.dropna(subset=column_list).to_dict(orient="records")
        for r in rows:
            r["_search_text"] = " ".join(str(r[col]) for col in column_list if r.get(col) is not None)
        texts = [r["_search_text"] for r in rows]

        previous_index_type = store.index_type
        if index_type:
            store.index_type = index_type
        try:
            store.encode_store(rows, texts)
        except Exception:
            # Do not leave the engine on an index type whose build just
            # failed; later requests and /health would keep reporting it.
            store.index_type = previous_index_type
            raise

        return {"status": "success", "count": len(rows), "used_columns": column_list, "index_type": store.index_type}
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) from e
# -------------------------
# Add CSV (append new rows)
# -------------------------
@app.post("/add_csv")
async def add_csv(file: UploadFile = File(...), columns: str = Form(...)):
    """
    Upload a CSV and append its rows to the existing store.

    Args:
        file: The CSV upload.
        columns: Comma-separated column names; their values are joined with
            spaces into each row's "_search_text".

    Returns:
        On success: status, number of rows added, the columns used and the
        total number of rows now in the store. If no column is given or a
        column is missing from the CSV: a dict with status "error", a detail
        message and the available columns.

    Raises:
        HTTPException: 500 with the underlying error text for any other failure.
    """
    try:
        contents = await file.read()
        df = pd.read_csv(io.BytesIO(contents))
        column_list = [c.strip() for c in columns.split(",") if c.strip()]

        # An empty selection (e.g. columns=",") would give every row an empty
        # search text, so reject it the same way as an unknown column.
        if not column_list:
            return {
                "status": "error",
                "detail": "No columns specified.",
                "available_columns": list(df.columns),
            }

        for col in column_list:
            if col not in df.columns:
                return {
                    "status": "error",
                    "detail": f"Column '{col}' not found.",
                    "available_columns": list(df.columns),
                }

        # Rows lacking a value in any selected column are skipped.
        new_rows = df.dropna(subset=column_list).to_dict(orient="records")
        for r in new_rows:
            r["_search_text"] = " ".join(str(r[col]) for col in column_list if r.get(col) is not None)
        new_texts = [r["_search_text"] for r in new_rows]

        store.add_rows(new_rows, new_texts)
        return {"status": "success", "added_count": len(new_rows), "used_columns": column_list, "total_rows": len(store.rows)}
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) from e
# -------------------------
# Search endpoint
# -------------------------
@app.get("/search")
async def search(
    query: str,
    top_k: int = 3,
    mode: str = Query("semantic", enum=["semantic", "lexical", "hybrid"]),
    alpha: float = 0.5,
):
    """
    Search the indexed rows.

    Args:
        query: The search text.
        top_k: Maximum number of results; non-positive values yield no results.
        mode: One of "semantic", "lexical" or "hybrid".
        alpha: Weight for semantic in hybrid (0..1).

    Returns:
        A dict with a "results" list. In lexical mode each result is the
        stored row plus a float "score"; the list is empty when no BM25 index
        exists.

    Raises:
        HTTPException: 400 for an unknown mode; 500 with the underlying error
            text when the search itself fails.
    """
    # NOTE(review): the `enum=` hint on Query appears to only document the
    # choices in the schema rather than validate them - confirm against the
    # FastAPI docs. Without this check an unknown mode would silently run a
    # hybrid search.
    if mode not in ("semantic", "lexical", "hybrid"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unknown mode '{mode}'. Expected one of: semantic, lexical, hybrid.",
        )

    # A negative top_k would make the lexical slice `[:top_k]` drop results
    # from the end instead of limiting them.
    if top_k <= 0:
        return {"results": []}

    try:
        if mode == "semantic":
            results = store.search(query, top_k=top_k)
        elif mode == "lexical":
            if store.bm25 is None:
                return {"results": []}
            tokenized_query = query.lower().split()
            scores = store.bm25.get_scores(tokenized_query)
            # Pair each score with its position so it can be mapped back to store.rows.
            ranked = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)[:top_k]
            results = [{**store.rows[i], "score": float(score)} for i, score in ranked]
        else:
            results = store.hybrid_search(query, top_k=top_k, alpha=alpha)
        return {"results": results}
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) from e
# -------------------------
# Delete all data
# -------------------------
@app.delete("/delete_data")
async def delete_data():
    """
    Clear the store via `store.clear_vdb()`.

    Raises:
        HTTPException: 500 with the underlying error text when clearing fails.
    """
    try:
        store.clear_vdb()
    except Exception as exc:
        message = str(exc)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=message)
    return {"status": "success", "message": "Vector DB cleared"}