File size: 5,660 Bytes
e4b6894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c46b826
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# main.py
import io
import logging
import os
from typing import Optional

import pandas as pd
from fastapi import FastAPI, HTTPException, status, File, UploadFile, Form, Query
from fastapi.middleware.cors import CORSMiddleware

from text_engine import Text_Search_Engine

app = FastAPI(title="CortexSearch", version="1.0", description="A flexible text search API with multiple FAISS index types and BM25 support.")

logger = logging.getLogger(__name__)

# Engine default index type is configurable via the INDEX_TYPE env var:
# "flat", "ivf", or "hnsw".
store = Text_Search_Engine(index_type=os.getenv("INDEX_TYPE", "flat"))
try:
    # Best-effort: restore a previously persisted index, if one exists.
    store.load()
except Exception:
    # Starting with an empty index is acceptable on first boot, but the
    # failure should be visible instead of silently swallowed.
    logger.warning("Could not load a persisted index; starting empty.", exc_info=True)

app.add_middleware(
    CORSMiddleware,
    # NOTE(review): wildcard origins together with allow_credentials=True is
    # rejected by browsers for credentialed requests — restrict origins (or
    # drop credentials) before production.
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root():
    """Liveness message for the API root."""
    payload = {"Status": "The CortexSearch API is live!!!"}
    return payload


# -------------------------
# Column preview endpoint
# -------------------------
@app.post("/list_columns")
async def list_columns(file: UploadFile = File(...)):
    """
    Upload a CSV and get available columns back.
    Useful to preview before choosing columns to index.
    """
    try:
        raw_bytes = await file.read()
        frame = pd.read_csv(io.BytesIO(raw_bytes))
    except Exception as exc:
        # Any read/parse failure is the client's problem: bad upload.
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc))
    else:
        return {"available_columns": list(frame.columns)}


# -------------------------
# Health check endpoint
# -------------------------
@app.get("/health")
async def health():
    """Report service status, indexed row count, and active index type."""
    return {
        "status": "ok",
        "rows_indexed": len(store.rows),
        "index_type": store.index_type,
    }


# -------------------------
# Upload CSV (build fresh index)
# -------------------------
@app.post("/upload_csv")
async def upload_csv(file: UploadFile = File(...), columns: str = Form(...), index_type: Optional[str] = Form(None)):
    """
    Upload a CSV and build a fresh index from it.

    'columns' is a comma-separated list of column names whose values are
    concatenated into the searchable text of each row.  The optional form
    field 'index_type' ('flat', 'ivf', or 'hnsw') overrides the engine
    default for this build.
    """
    try:
        contents = await file.read()
        df = pd.read_csv(io.BytesIO(contents))

        column_list = [c.strip() for c in columns.split(",") if c.strip()]
        # Validate requested columns against the uploaded file.
        for col in column_list:
            if col not in df.columns:
                return {
                    "status": "error",
                    "detail": f"Column '{col}' not found.",
                    "available_columns": list(df.columns),
                }

        # Reject unknown index types up front instead of failing deep inside
        # the engine with an opaque 500.
        if index_type and index_type not in ("flat", "ivf", "hnsw"):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid index_type '{index_type}'; expected 'flat', 'ivf', or 'hnsw'.",
            )

        # Rows missing any selected column are dropped; the remaining values
        # are joined into one searchable string per row.
        rows = df.dropna(subset=column_list).to_dict(orient="records")
        for r in rows:
            r["_search_text"] = " ".join(str(r[col]) for col in column_list if r.get(col) is not None)

        texts = [r["_search_text"] for r in rows]

        if index_type:
            store.index_type = index_type

        store.encode_store(rows, texts)
        return {"status": "success", "count": len(rows), "used_columns": column_list, "index_type": store.index_type}
    except HTTPException:
        # Preserve deliberate 4xx responses; don't convert them to 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))


# -------------------------
# Add CSV (append new rows)
# -------------------------
@app.post("/add_csv")
async def add_csv(file: UploadFile = File(...), columns: str = Form(...)):
    """Append rows from an uploaded CSV to the existing index."""
    try:
        payload = await file.read()
        frame = pd.read_csv(io.BytesIO(payload))

        selected = [name.strip() for name in columns.split(",") if name.strip()]
        # Report the first requested column the file does not contain.
        missing = next((name for name in selected if name not in frame.columns), None)
        if missing is not None:
            return {
                "status": "error",
                "detail": f"Column '{missing}' not found.",
                "available_columns": list(frame.columns),
            }

        appended = frame.dropna(subset=selected).to_dict(orient="records")
        for record in appended:
            pieces = [str(record[name]) for name in selected if record.get(name) is not None]
            record["_search_text"] = " ".join(pieces)

        store.add_rows(appended, [record["_search_text"] for record in appended])

        return {
            "status": "success",
            "added_count": len(appended),
            "used_columns": selected,
            "total_rows": len(store.rows),
        }
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))


# -------------------------
# Search endpoint
# -------------------------
@app.get("/search")
async def search(
    query: str,
    top_k: int = Query(3, ge=1, description="Number of results to return."),
    mode: str = Query("semantic", enum=["semantic", "lexical", "hybrid"]),
    alpha: float = Query(0.5, ge=0.0, le=1.0, description="Semantic weight for hybrid mode."),
):
    """
    Search the index.

    mode: 'semantic' (vector search), 'lexical' (BM25 ranking), or
    'hybrid' (weighted blend; 'alpha' is the semantic weight in [0, 1]).
    Out-of-range 'top_k'/'alpha' are rejected with a 422 by FastAPI.
    """
    try:
        if mode == "semantic":
            results = store.search(query, top_k=top_k)
        elif mode == "lexical":
            # Without a built BM25 index there is nothing to rank against.
            if store.bm25 is None:
                return {"results": []}
            tokenized_query = query.lower().split()
            scores = store.bm25.get_scores(tokenized_query)
            ranked = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)[:top_k]
            results = [{**store.rows[i], "score": float(score)} for i, score in ranked]
        else:
            results = store.hybrid_search(query, top_k=top_k, alpha=alpha)

        return {"results": results}
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))


# -------------------------
# Delete all data
# -------------------------
@app.delete("/delete_data")
async def delete_data():
    """Wipe the vector store entirely."""
    try:
        store.clear_vdb()
    except Exception as err:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(err))
    return {"status": "success", "message": "Vector DB cleared"}