import os

# Hugging Face caches must point at a writable directory (on many managed
# hosts only /tmp is writable) and must be set BEFORE the HF libraries are
# imported, or they are ignored.
os.environ.setdefault("HF_HOME", "/tmp/huggingface")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/huggingface")
os.environ.setdefault("HF_HUB_CACHE", "/tmp/huggingface")
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "/tmp/huggingface/st_models")

import re
from collections import deque
from contextlib import closing

import openai
import psycopg2
import streamlit as st
from sentence_transformers import SentenceTransformer

# ── Setup ──────────────────────────────────────────────────
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
ll_model = 'gpt-4o-mini'


# ── PostgreSQL connection ──────────────────────────────────
def get_db_connection():
    """Open a new PostgreSQL connection from the RDS_* environment variables."""
    return psycopg2.connect(
        host=os.getenv("RDS_HOST"),
        port=os.getenv("RDS_PORT", 5432),
        dbname=os.getenv("RDS_DB"),
        user=os.getenv("RDS_USER"),
        password=os.getenv("RDS_PASS"),
    )


# ── BGE embedding model ────────────────────────────────────
model = SentenceTransformer('BAAI/bge-small-en-v1.5')


def get_embedding(text):
    """Embed *text* for retrieval and return the vector as a plain list.

    BGE models expect this instruction prefix on the *query* side only
    (indexed passages are embedded without it).
    """
    # BGE requires this prefix for queries
    prefixed = f"Represent this sentence for searching relevant passages: {text}"
    return model.encode(prefixed).tolist()


# ── STEP 1: vector search over chunk summaries ─────────────
def retrieve_summaries(query, top_k=40):
    """Return the *top_k* chunk summaries nearest to *query* (cosine distance).

    Each result dict has keys id / case_id / chunk_index / chunk_summary /
    similarity.  On any failure the error is surfaced in the UI and an empty
    list is returned.
    """
    try:
        embedding = get_embedding(query)
        # closing() guarantees the connection is released even if execute()
        # raises — the original leaked the connection on query errors.
        with closing(get_db_connection()) as conn, conn.cursor() as cur:
            cur.execute("""
                SELECT
                    id,
                    case_id,
                    chunk_index,
                    chunk_summary,
                    1 - (embedding <=> %s::vector) AS similarity
                FROM public.case_chunks
                ORDER BY embedding <=> %s::vector
                LIMIT %s;
            """, [embedding, embedding, top_k])
            rows = cur.fetchall()
        return [
            {
                "id": row[0],
                "case_id": row[1],
                "chunk_index": row[2],
                "chunk_summary": row[3],
                "similarity": row[4],
            }
            for row in rows
        ]
    except Exception as e:
        st.error(f"Retrieve error: {e}")
        return []


# ── STEP 2: LLM picks best chunks based on summaries ───────
def rerank_with_llm(query, candidates, final_k=10):
    """Ask the LLM to choose the *final_k* most relevant candidate chunk IDs.

    IDs the model invents (not present in *candidates*) are discarded.
    Falls back to the top *final_k* candidates by vector similarity when the
    API call fails.
    """
    if not candidates:
        return []

    summary_list = "\n".join(
        f"[ID: {c['id']}] Case: {c['case_id']} | Summary: {c['chunk_summary']}"
        for c in candidates
    )
    messages = [
        {"role": "system",
         "content":
             "You are a legal research assistant. Given a user query and a list of document chunk summaries, "
             "select the most relevant chunk IDs that would best answer the query. "
             "Return ONLY a comma-separated list of IDs, nothing else. "
             "Example: 12,45,67,23"},
        {"role": "user",
         "content": f"Query: {query}\n\n"
                    f"Chunks:\n{summary_list}\n\n"
                    f"Select the {final_k} most relevant chunk IDs."},
    ]
    try:
        resp = client.chat.completions.create(
            model=ll_model,
            messages=messages,
            temperature=0.0,
            max_tokens=200
        )
        raw = resp.choices[0].message.content.strip()
        parsed = [int(tok) for tok in (t.strip() for t in raw.split(",")) if tok.isdigit()]
        # Keep only IDs that actually exist among the candidates, preserving
        # the LLM's preference order.
        valid_ids = {c["id"] for c in candidates}
        return [i for i in parsed if i in valid_ids][:final_k]
    except Exception as e:
        st.error(f"Rerank error: {e}")
        # Fallback: just return top final_k by similarity
        return [c["id"] for c in candidates[:final_k]]


# ── STEP 3: fetch full chunk_text for selected IDs only ────
def fetch_chunks_by_ids(selected_ids):
    """Fetch full rows for *selected_ids*, preserving the given ranked order.

    `WHERE id = ANY(...)` returns rows in arbitrary database order, so the
    results are re-sorted to match the reranker's ordering.
    """
    try:
        with closing(get_db_connection()) as conn, conn.cursor() as cur:
            cur.execute("""
                SELECT id, case_id, chunk_index, chunk_text, chunk_summary
                FROM public.case_chunks
                WHERE id = ANY(%s);
            """, [selected_ids])
            rows = cur.fetchall()
        docs = [
            {
                "id": row[0],
                "case_id": row[1],
                "chunk_index": row[2],
                "chunk_text": row[3],
                "chunk_summary": row[4],
            }
            for row in rows
        ]
        # Restore the relevance order chosen by the reranker; unknown IDs
        # (shouldn't happen) sink to the end.
        rank = {cid: i for i, cid in enumerate(selected_ids)}
        docs.sort(key=lambda d: rank.get(d["id"], len(rank)))
        return docs
    except Exception as e:
        st.error(f"Fetch error: {e}")
        return []


st.title("AI Legal Assistant ⚖️")

if "history" not in st.session_state:
    # Rolling window of the last 10 chat messages (user + assistant).
    st.session_state.history = deque(maxlen=10)


def get_rewritten_query(user_query):
    """Rewrite *user_query* into a self-contained search query using recent history.

    Returns the original query unchanged if the rewrite call fails.
    """
    hist = list(st.session_state.history)[-4:]
    hist_text = "\n".join(f"{m['role']}: {m['content']}" for m in hist)
    messages = [
        {"role": "system",
         "content": "You are a legal assistant that rewrites user queries into clear, context-aware queries for vector DB lookup. If its already clear then dont rewrite"},
        {"role": "user",
         "content": f"History:\n{hist_text}\n\nNew query:\n{user_query}\n\n"
                    "Rewrite if needed for clarity/search purposes. Otherwise, repeat exactly."},
    ]
    try:
        resp = client.chat.completions.create(
            model=ll_model,
            messages=messages,
            temperature=0.1,
            max_tokens=400
        )
        return resp.choices[0].message.content.strip()
    except Exception as e:
        st.error(f"Rewrite error: {e}")
        return user_query


# ── COMBINED: full retrieval pipeline ──────────────────────
def retrieve_documents(query, top_k=10):
    """Summary search → LLM rerank → fetch full text for the chosen chunks."""
    # 1. Over-fetch 4x candidates by vector similarity on summaries.
    candidates = retrieve_summaries(query, top_k=top_k * 4)
    if not candidates:
        return []
    # 2. Let the LLM select the best IDs from the summaries alone.
    selected_ids = rerank_with_llm(query, candidates, final_k=top_k)
    if not selected_ids:
        return []
    # 3. Pull the full chunk_text only for the selected chunks.
    return fetch_chunks_by_ids(selected_ids)


def clean_chunk_id(cid: str) -> str:
    """Turn an internal chunk id like 'roe_v_wade_chunk3' into 'Roe V Wade'."""
    cid = re.sub(r'_chunk.*$', '', cid)
    cid = cid.replace("_", " ").replace("-", " ")
    return " ".join(word.capitalize() for word in cid.split())


# ── Answer generation ──────────────────────────────────────
def generate_response(user_query, docs):
    """Answer *user_query* from *docs*, render the reply and sources, return it."""
    # Concatenate retrieved chunk text into a single context window.
    context = "\n\n---\n\n".join(d['chunk_text'] for d in docs if d['chunk_text'])

    # Human-friendly source label → full chunk text (deduplicated by label).
    source_links = {}
    for d in docs:
        case_id = d.get("case_id", "unknown")
        chunk_idx = d.get("chunk_index", "")
        text_preview = " ".join((d.get("chunk_text") or "").split()[:30])
        if case_id == "constitution":
            display_name = f"Constitution (Chunk {chunk_idx})"
        else:
            display_name = f"Case Law: {text_preview}..."
        source_links[display_name] = d.get("chunk_text", "")
    source_links = dict(sorted(source_links.items()))

    messages = [
        {"role": "system",
         "content":
             "You are a helpful legal assistant. Use the provided context from documents to answer the user's question. "
             "At the end of your answer, write a single line starting with 'Source: ' followed by the sources used. "
             "Formatting rules:\n"
             "- For Constitution: show the chunk number.\n"
             "- For Case law: show first ~30 words of the case text.\n"
             "- Do not use technical terms like 'chunk'. Present sources in a human-friendly way.\n"
             "If multiple are used, separate them with commas."}
    ]
    messages.extend(list(st.session_state.history))
    messages.append({"role": "user",
                     "content": f"Context:\n{context}\n\n"
                                f"Sources:\n{', '.join(source_links.keys())}\n\n"
                                f"Question:\n{user_query}"})

    generated_ok = True
    try:
        resp = client.chat.completions.create(
            model=ll_model,
            messages=messages,
            temperature=0.1,
            max_tokens=900
        )
        reply = resp.choices[0].message.content.strip()
    except Exception as e:
        st.error(f"Response error: {e}")
        reply = "Sorry, I encountered an error generating the answer."
        generated_ok = False

    # Append a Source line only when generation succeeded and the model
    # didn't already include one (the original also tacked sources onto the
    # error-fallback reply).
    if generated_ok and source_links and "Source:" not in reply:
        reply += f"\n\nSource: {', '.join(source_links.keys())}"

    st.session_state.history.append({"role": "assistant", "content": reply})
    st.markdown(reply)

    if source_links:
        st.write("### Sources")
        for name, text in source_links.items():
            with st.expander(name):
                st.write(text)
    return reply


# ── Chat UI ────────────────────────────────────────────────
with st.form("chat_input", clear_on_submit=True):
    user_input = st.text_input("You:", "")
    submit = st.form_submit_button("Send")

if submit and user_input:
    st.session_state.history.append({"role": "user", "content": user_input})
    rewritten = get_rewritten_query(user_input)
    docs = retrieve_documents(rewritten)
    assistant_reply = generate_response(rewritten, docs)

# Transcript, newest message first.
st.markdown("---")
for c, msg in enumerate(reversed(st.session_state.history), start=1):
    if msg["role"] == "user":
        st.markdown(f"**You:** {msg['content']}")
    else:
        st.markdown(f"**Legal Assistant:** {msg['content']}")
    # The original used the obscure bitwise test `if c ^ 1:`, which is truthy
    # for every count except 1 — i.e. draw a divider after all but the newest
    # message. Same behavior, stated plainly:
    if c != 1:
        st.markdown("---")