import os
import sqlite3
import time
from itertools import combinations
from typing import List

import chromadb
import pandas as pd
import requests
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
|
|
| |
app = FastAPI()

# Browser origins allowed to call this API (the front-end dev server).
origins = [
    "http://localhost:5173",
    # NOTE(review): browsers send the Origin header including the scheme, so this
    # scheme-less entry most likely never matches; kept for compatibility.
    "localhost:5173",
]

# Cross-origin access: credentials allowed, any method and header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
| |
# Sentence-embedding model used to vectorize user-typed symptoms before the
# similarity search in Chroma.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# On-disk Chroma store and the symptom collection queried by the endpoints.
# NOTE(review): nothing in this file adds documents to the collection; it is
# presumably populated by a separate script — confirm.
client = chromadb.PersistentClient(path='./chromadb')
collection = client.get_or_create_collection(name="symptomsvector")
|
|
| |
def init_db(db_path="diseases_symptoms.db"):
    """Open the SQLite database and ensure the ``diseases`` table exists.

    Args:
        db_path: SQLite file to open (created if missing). Defaults to the
            file the rest of this module reads from.

    Returns:
        An open ``sqlite3.Connection``; the caller must close it.
    """
    conn = sqlite3.connect(db_path)
    try:
        conn.execute(
            '''
            CREATE TABLE IF NOT EXISTS diseases (
                id INTEGER PRIMARY KEY,
                name TEXT,
                symptoms TEXT,
                treatments TEXT
            )
            '''
        )
        conn.commit()
    except BaseException:
        # Don't leak the connection when schema creation fails.
        conn.close()
        raise
    return conn
|
|
| |
# First run only: create the SQLite database and seed it from the Hugging Face
# Diseases_Symptoms CSV.
if not os.path.exists("diseases_symptoms.db"):
    conn = init_db()
    try:
        df = pd.read_csv("hf://datasets/QuyenAnhDE/Diseases_Symptoms/Diseases_Symptoms.csv")
        # Split the comma-separated symptom text into stripped entries; missing
        # values become an empty string first so .split() never sees a float NaN.
        df['Symptoms'] = df['Symptoms'].fillna('').str.split(',').apply(
            lambda parts: [s.strip() for s in parts]
        )
        # Store missing treatments as '' rather than NULL.
        if 'Treatments' in df.columns:
            df['Treatments'] = df['Treatments'].fillna('')

        conn.executemany(
            "INSERT INTO diseases (name, symptoms, treatments) VALUES (?, ?, ?)",
            (
                (row['Name'], ",".join(row['Symptoms']), row.get('Treatments', ''))
                for _, row in df.iterrows()
            ),
        )
        conn.commit()
    except BaseException:
        # Remove the partially-seeded file: the existence check above would
        # otherwise skip seeding forever and serve an empty/incomplete DB.
        conn.close()
        os.remove("diseases_symptoms.db")
        raise
    conn.close()
|
|
class SymptomQuery(BaseModel):
    """Request body holding user-typed symptoms.

    The endpoints split ``symptom`` on commas, so several symptoms can be sent
    at once, e.g. ``"fever, cough"``.
    """

    symptom: str  # comma-separated symptom text
|
|
| |
def fetch_diseases_by_symptoms(matching_symptoms, db_path="diseases_symptoms.db"):
    """Find diseases whose stored symptom string contains the given symptoms.

    The symptoms are joined with commas and matched as one contiguous substring
    of the stored ``symptoms`` column via SQLite ``LIKE``. That comparison is
    ASCII case-insensitive. ``%`` and ``_`` in the input are matched literally.

    Args:
        matching_symptoms: Iterable of symptom strings. They must appear
            consecutively and in this order in the stored string.
        db_path: SQLite database to read.

    Returns:
        ``(disease_list, unique_symptoms_list)``. Each disease is a dict
        ``{'Disease', 'Symptoms', 'Treatments'}`` with ``Symptoms`` split on
        commas. ``unique_symptoms_list`` is every symptom of the returned
        diseases, stripped and lower-cased, de-duplicated in first-seen order.
    """
    matching_symptom_str = ','.join(matching_symptoms)
    # Escape LIKE metacharacters so user text is a literal substring, not a pattern.
    escaped = (
        matching_symptom_str.replace('\\', '\\\\')
        .replace('%', '\\%')
        .replace('_', '\\_')
    )

    disease_list = []
    # dict keeps insertion order with O(1) membership (a list made this O(n^2)).
    unique_symptoms = {}
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute(
            "SELECT name, symptoms, treatments FROM diseases "
            "WHERE symptoms LIKE ? ESCAPE '\\'",
            (f'%{escaped}%',),
        )
        for name, symptoms, treatments in rows:
            disease_list.append({
                'Disease': name,
                'Symptoms': symptoms.split(','),
                'Treatments': treatments,
            })
            for symptom in symptoms.split(','):
                unique_symptoms.setdefault(symptom.strip().lower(), None)
    finally:
        conn.close()

    return disease_list, list(unique_symptoms)
|
|
@app.post("/find_matching_symptoms")
def find_matching_symptoms(query: SymptomQuery):
    """Suggest stored symptoms that are semantically close to the user's input.

    ``query.symptom`` is split on commas. Each non-blank entry is embedded with
    ``model``, and its 3 nearest documents are fetched from ``collection``.

    Returns:
        ``{"matching_symptoms": [...]}``: the suggestions, de-duplicated and kept
        in first-seen order (input order, then Chroma's per-query result order).
    """
    # Skip blank entries (e.g. from "fever," or an empty string). Embedding ""
    # only yields arbitrary neighbours.
    symptoms = [s.strip() for s in query.symptom.split(',') if s.strip()]
    if not symptoms:
        return {"matching_symptoms": []}

    # One batched encode/query instead of a round-trip per symptom. Chroma
    # returns one document list per query embedding, in the same order.
    query_embeddings = model.encode(symptoms)
    results = collection.query(
        query_embeddings=query_embeddings.tolist(),
        n_results=3,
    )
    all_results = [doc for documents in results['documents'] for doc in documents]

    return {"matching_symptoms": list(dict.fromkeys(all_results))}
|
|
@app.post("/find_disease_list")
def find_disease_list(query: SymptomQuery):
    """List diseases related to the user's typed symptoms.

    Processing steps:
    - Every non-blank, lower-cased entry of the comma-separated
      ``query.symptom`` is added to the shared ``all_selected_symptoms`` set.
    - Each entry is embedded, and its 5 nearest documents are fetched from
      ``collection``.
    - A disease is returned when at least one of those documents equals one of
      its stored symptoms (both compared stripped and lower-cased).

    Returns:
        ``{"disease_list": [...], "unique_symptoms_list": [...]}``.
        Each disease is ``{'Disease', 'Symptoms', 'Treatments'}``.
        ``unique_symptoms_list`` is the sorted set of symptoms of the returned
        diseases that the user did not type.
    """
    selected_symptoms = [s.strip().lower() for s in query.symptom.split(',')]
    # Drop blanks (e.g. from a trailing comma). An empty string in the shared
    # set would make later /find_disease calls match (almost) no disease.
    selected_symptoms = [s for s in selected_symptoms if s]
    all_selected_symptoms.update(selected_symptoms)

    matching_symptoms = set()
    if selected_symptoms:
        # One batched encode/query; Chroma returns one document list per embedding.
        query_embeddings = model.encode(selected_symptoms)
        results = collection.query(
            query_embeddings=query_embeddings.tolist(),
            n_results=5,
        )
        for documents in results['documents']:
            # DB symptoms are compared lower-cased, so normalize the hits the
            # same way; otherwise capitalized documents could never match.
            matching_symptoms.update(doc.strip().lower() for doc in documents if doc)

    typed = set(selected_symptoms)
    disease_list = []
    unique_symptoms_set = set()
    conn = sqlite3.connect("diseases_symptoms.db")
    try:
        for name, symptoms, treatments in conn.execute(
            "SELECT name, symptoms, treatments FROM diseases"
        ):
            disease_symptoms = [s.strip().lower() for s in (symptoms or '').split(',')]
            if matching_symptoms.isdisjoint(disease_symptoms):
                continue
            disease_list.append({
                'Disease': name,
                'Symptoms': disease_symptoms,
                'Treatments': treatments,
            })
            # Offer the remaining symptoms of matched diseases as follow-ups.
            unique_symptoms_set.update(s for s in disease_symptoms if s not in typed)
    finally:
        conn.close()

    return {
        "disease_list": disease_list,
        "unique_symptoms_list": sorted(unique_symptoms_set),
    }
|
|
|
|
|
|
class SelectedSymptomsQuery(BaseModel):
    """Request body for ``/find_disease``: symptoms the user has picked."""

    # Element type is declared so pydantic validates/coerces each item to str.
    # With a bare ``list``, a non-string item made the endpoint fail on
    # ``.strip()`` with a 500 instead of a validation error.
    selected_symptoms: List[str]
|
|
| |
# Module-level accumulator of lower-cased symptoms. /find_disease_list and
# /find_disease add to it; /trigger-reload empties it.
# NOTE(review): the set is shared by every client of this process, so
# concurrent users see (and pollute) each other's selections.
all_selected_symptoms: set = set()
|
|
@app.post("/find_disease")
def find_disease(query: SelectedSymptomsQuery):
    """Narrow the disease list using every symptom selected so far.

    The request's symptoms are stripped, lower-cased and merged into the shared
    ``all_selected_symptoms`` set. A disease qualifies only if its stored
    symptoms include all of them.

    Returns:
        ``{"unique_symptoms_list", "all_selected_symptoms", "disease_list"}``:
        - ``unique_symptoms_list``: sorted symptoms of the qualifying diseases
          that are not already selected.
        - ``all_selected_symptoms``: the selection used for matching.
        - ``disease_list``: dicts ``{'Disease', 'Symptoms', 'Treatments'}``.
    """
    new_symptoms = [s.strip().lower() for s in query.selected_symptoms]
    # Ignore blanks. An empty string in the set would exclude (almost) every
    # disease for all later requests.
    all_selected_symptoms.update(s for s in new_symptoms if s)
    # Snapshot the shared set. Sync endpoints run in FastAPI's thread pool, so
    # another request may mutate it while we iterate over the table.
    selected = frozenset(all_selected_symptoms)

    disease_list = []
    unique_symptoms_set = set()
    conn = sqlite3.connect("diseases_symptoms.db")
    try:
        for name, symptoms, treatments in conn.execute(
            "SELECT name, symptoms, treatments FROM diseases"
        ):
            disease_symptoms = [s.strip().lower() for s in (symptoms or '').split(',')]
            if not selected.issubset(disease_symptoms):
                continue
            disease_list.append({
                'Disease': name,
                'Symptoms': disease_symptoms,
                'Treatments': treatments,
            })
            unique_symptoms_set.update(s for s in disease_symptoms if s not in selected)
    finally:
        conn.close()

    return {
        "unique_symptoms_list": sorted(unique_symptoms_set),
        "all_selected_symptoms": list(selected),
        "disease_list": disease_list,
    }
|
|
class DiseaseDetail(BaseModel):
    """Request body for ``/pass2llm``: one disease record to be summarized."""

    Disease: str  # disease name
    Symptoms: list  # the disease's symptoms (element type not enforced)
    Treatments: str  # treatment text as stored in the database
    MatchCount: int  # required; only reaches the LLM through the prompt's repr of the model
|
|
@app.post("/pass2llm")
def pass2llm(query: DiseaseDetail):
    """Ask the ``llama3`` model behind an ngrok tunnel to summarize a disease.

    Steps:
    1. Look up the tunnel's public URL through the ngrok REST API. The API key
       is read from the ``NGROK_API_KEY`` environment variable.
    2. POST a non-streaming generation request to
       ``{public_url}/api/generate``.

    Returns:
        On success, ``{"message": "Successfully passed to LLM!", "llm_response": ...}``.
        Otherwise, ``{"message": <failure description>, "error": <details>}``.
    """
    ngrok_failure = "Failed to get public URL from Ngrok!"
    llm_failure = "Failed to get response from LLM!"

    # SECURITY: a live ngrok API key used to be hard-coded here. It has been
    # exposed in source control and must be revoked; supply a fresh key via
    # the environment instead of committing it.
    api_key = os.environ.get("NGROK_API_KEY")
    if not api_key:
        return {"message": ngrok_failure, "error": "NGROK_API_KEY environment variable is not set"}

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Ngrok-Version": "2",
    }
    try:
        # Without a timeout, a stalled request would pin a worker thread indefinitely.
        response = requests.get("https://api.ngrok.com/endpoints", headers=headers, timeout=10)
    except requests.RequestException as exc:
        return {"message": ngrok_failure, "error": str(exc)}
    if response.status_code != 200:
        return {"message": ngrok_failure, "error": response.text}

    try:
        endpoints = response.json().get('endpoints') or []
    except ValueError:
        return {"message": ngrok_failure, "error": response.text}
    if not endpoints:
        # Previously an IndexError (HTTP 500) when no tunnel was running.
        return {"message": ngrok_failure, "error": "No active ngrok endpoints found"}
    public_url = endpoints[0]['public_url']

    prompt = f"Here is a list of diseases and their details: {query}. Please generate a summary."
    llm_headers = {"Content-Type": "application/json"}
    llm_payload = {"model": "llama3", "prompt": prompt, "stream": False}
    try:
        # Non-streaming generation can be slow, so allow a long read timeout.
        llm_response = requests.post(
            f"{public_url}/api/generate",
            headers=llm_headers,
            json=llm_payload,
            timeout=(10, 300),
        )
    except requests.RequestException as exc:
        return {"message": llm_failure, "error": str(exc)}
    if llm_response.status_code != 200:
        return {"message": llm_failure, "error": llm_response.text}

    try:
        llm_response_json = llm_response.json()
    except ValueError:
        return {"message": llm_failure, "error": llm_response.text}
    return {"message": "Successfully passed to LLM!", "llm_response": llm_response_json.get("response")}
|
|
|
|
|
|
|
|
@app.post("/trigger-reload")
async def trigger_reload():
    """Forget every symptom accumulated in ``all_selected_symptoms``.

    Returns the literal string ``"cleared"``.
    """
    # .clear() mutates the module-level set in place; no ``global`` rebinding is needed.
    all_selected_symptoms.clear()
    return "cleared"
|
|
| |
| |
| |
|
|