| | from fastapi import FastAPI |
| | from pydantic import BaseModel |
| | from sentence_transformers import SentenceTransformer |
| | import chromadb |
| | from fastapi.middleware.cors import CORSMiddleware |
| | import uvicorn |
| | import requests |
| | from itertools import combinations |
| | import sqlite3 |
| | import pandas as pd |
| | import os |
| | import time |
| |
|
| | |
app = FastAPI()

# Browser origins that may call this API (the front-end dev server).
# NOTE(review): "localhost:5173" carries no scheme, so it can never equal the
# Origin header a browser sends; only the "http://" entry has any effect.
origins = [
    "http://localhost:5173",
    "localhost:5173",
]

# CORS for the origins above: credentials allowed, any method, any header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Sentence-embedding model used to vectorise symptom text for similarity search.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# On-disk Chroma store under ./chromadb. The "symptomsvector" collection is
# queried by the symptom-matching endpoints; how it is populated is not shown here.
client = chromadb.PersistentClient(path="./chromadb")
collection = client.get_or_create_collection(name="symptomsvector")
| |
|
| | |
def init_db():
    """Open ``diseases_symptoms.db`` and make sure the ``diseases`` table exists.

    The table is created only if missing, so calling this repeatedly is harmless.

    Returns:
        The open ``sqlite3.Connection``; the caller is responsible for closing it.
    """
    connection = sqlite3.connect("diseases_symptoms.db")
    schema = '''
        CREATE TABLE IF NOT EXISTS diseases (
            id INTEGER PRIMARY KEY,
            name TEXT,
            symptoms TEXT,
            treatments TEXT
        )
    '''
    connection.execute(schema)
    connection.commit()
    return connection
| |
|
| | |
# Seed the SQLite database from the public Hugging Face CSV.
# Seeding is keyed on the table being empty rather than on the file existing:
# init_db() creates the file *before* the download, so a failed download used to
# leave an empty database that was never filled on later starts.
conn = init_db()
try:
    (disease_count,) = conn.execute("SELECT COUNT(*) FROM diseases").fetchone()
    if disease_count == 0:
        df = pd.read_csv("hf://datasets/QuyenAnhDE/Diseases_Symptoms/Diseases_Symptoms.csv")
        # Missing cells load as NaN: a NaN symptom cell cannot be split/iterated,
        # and a NaN treatment would not be stored as text. Use '' for both.
        symptom_cells = df['Symptoms'].fillna('').astype(str)
        if 'Treatments' in df.columns:
            treatment_cells = df['Treatments'].fillna('').astype(str)
        else:
            treatment_cells = [''] * len(df)
        # Store symptoms as a comma-joined string with surrounding whitespace removed.
        seed_rows = [
            (name, ",".join(part.strip() for part in cell.split(',')), treatment)
            for name, cell, treatment in zip(df['Name'], symptom_cells, treatment_cells)
        ]
        # One transaction for the whole import instead of a commit per row.
        conn.executemany(
            "INSERT INTO diseases (name, symptoms, treatments) VALUES (?, ?, ?)",
            seed_rows,
        )
        conn.commit()
finally:
    conn.close()
| |
|
class SymptomQuery(BaseModel):
    """JSON request body holding free-text symptoms in a single string.

    The endpoints that accept this model split ``symptom`` on commas, so several
    symptoms can be sent at once, e.g. ``"fever, cough"``.
    """

    symptom: str
| |
|
| | |
def fetch_diseases_by_symptoms(matching_symptoms, db_path="diseases_symptoms.db"):
    """Return the diseases whose stored symptoms mention every given symptom.

    Each item of ``matching_symptoms`` must occur as a substring of the stored
    comma-separated ``symptoms`` column. SQLite ``LIKE`` makes the test
    case-insensitive for ASCII letters. Items may occur in any order, and ``%`` /
    ``_`` inside an item are matched literally. An empty ``matching_symptoms``
    matches every row.

    Args:
        matching_symptoms: Iterable of symptom strings to look for.
        db_path: SQLite database file to read (defaults to the app database).

    Returns:
        A tuple ``(disease_list, unique_symptoms_list)``. ``disease_list`` holds one
        ``{'Disease', 'Symptoms', 'Treatments'}`` dict per matching row.
        ``unique_symptoms_list`` holds every symptom of those rows, stripped and
        lower-cased, without duplicates, in first-seen order.
    """
    # One LIKE per symptom: the old single pattern '%a,b%' only matched rows in
    # which the symptoms were stored next to each other and in exactly that order.
    conditions = []
    params = []
    for symptom in matching_symptoms:
        # Escape LIKE wildcards so user text is matched literally.
        escaped = symptom.replace('\\', '\\\\').replace('%', '\\%').replace('_', '\\_')
        conditions.append("symptoms LIKE ? ESCAPE '\\'")
        params.append(f'%{escaped}%')
    sql = "SELECT name, symptoms, treatments FROM diseases"
    if conditions:
        sql += " WHERE " + " AND ".join(conditions)

    disease_list = []
    unique_symptoms_list = []
    seen = set()  # O(1) membership checks while keeping the list's first-seen order
    conn = sqlite3.connect(db_path)
    try:
        for name, symptoms_csv, treatments in conn.execute(sql, params):
            disease_list.append({
                'Disease': name,
                'Symptoms': symptoms_csv.split(','),
                'Treatments': treatments,
            })
            for symptom in symptoms_csv.split(','):
                normalized = symptom.strip().lower()
                if normalized not in seen:
                    seen.add(normalized)
                    unique_symptoms_list.append(normalized)
    finally:
        conn.close()
    return disease_list, unique_symptoms_list
| |
|
@app.post("/find_matching_symptoms")
def find_matching_symptoms(query: SymptomQuery):
    """Map free-text symptoms to the closest documents in the vector store.

    ``query.symptom`` is split on commas. Each non-empty, stripped item is
    embedded with ``model``, and its 3 nearest documents are fetched from the
    Chroma ``collection``. All hits are pooled and de-duplicated in first-seen
    order.

    Returns:
        ``{"matching_symptoms": [...]}`` (empty when no non-blank item was sent).
    """
    # Skip blanks produced by stray or trailing commas; embedding '' would only
    # return arbitrary neighbours.
    terms = [term.strip() for term in query.symptom.split(',') if term.strip()]
    if not terms:
        return {"matching_symptoms": []}

    # Encode and query every term in one batch; Chroma returns one document list
    # per query embedding, in input order, so the pooled order is unchanged.
    embeddings = model.encode(terms)
    results = collection.query(
        query_embeddings=embeddings.tolist(),
        n_results=3
    )

    all_results = [doc for docs in results['documents'] for doc in docs]
    matching_symptoms = list(dict.fromkeys(all_results))
    return {"matching_symptoms": matching_symptoms}
| |
|
@app.post("/find_disease_list")
def find_disease_list(query: SymptomQuery):
    """List diseases related to the user's free-text symptoms.

    The steps are:

    * ``query.symptom`` is split on commas. Each non-empty item is stripped,
      lower-cased and added to the module-level ``all_selected_symptoms`` set.
    * Each item is embedded, and its 5 nearest documents are fetched from the
      Chroma ``collection``.
    * A disease is returned when at least one of those documents equals one of
      its symptoms; both sides are compared stripped and lower-cased.

    Returns:
        A dict with two keys:

        * ``"disease_list"``: ``{'Disease', 'Symptoms', 'Treatments'}`` dicts.
        * ``"unique_symptoms_list"``: sorted symptoms of the returned diseases
          that were not part of this request.
    """
    # Blank items (from "a,,b" or a trailing comma) would otherwise be stored in
    # all_selected_symptoms and make /find_disease unable to match any disease.
    selected_symptoms = [symptom.strip().lower() for symptom in query.symptom.split(',')]
    selected_symptoms = [symptom for symptom in selected_symptoms if symptom]
    all_selected_symptoms.update(selected_symptoms)

    matching_symptoms = set()
    if selected_symptoms:
        # One batched encode/query; Chroma returns one document list per embedding.
        embeddings = model.encode(selected_symptoms)
        results = collection.query(
            query_embeddings=embeddings.tolist(),
            n_results=5
        )
        # Normalise the documents the same way as the DB symptoms below, so the
        # comparison does not depend on the casing used in the vector store.
        matching_symptoms = {
            doc.strip().lower() for docs in results['documents'] for doc in docs
        }
    selected_set = set(selected_symptoms)

    disease_list = []
    unique_symptoms_set = set()
    conn = sqlite3.connect("diseases_symptoms.db")
    try:
        rows = conn.execute("SELECT name, symptoms, treatments FROM diseases")
        for disease_name, symptoms_csv, treatments in rows:
            disease_symptoms = [symptom.strip().lower() for symptom in symptoms_csv.split(',')]
            if matching_symptoms.isdisjoint(disease_symptoms):
                continue
            disease_list.append({
                'Disease': disease_name,
                'Symptoms': disease_symptoms,
                'Treatments': treatments
            })
            # Offer the other symptoms of each matched disease as follow-up choices.
            unique_symptoms_set.update(
                symptom for symptom in disease_symptoms if symptom not in selected_set
            )
    finally:
        conn.close()

    return {
        "disease_list": disease_list,
        "unique_symptoms_list": sorted(unique_symptoms_set)
    }
| |
|
| |
|
| |
|
class SelectedSymptomsQuery(BaseModel):
    """JSON request body carrying the symptoms a user has picked.

    The handler that consumes it strips and lower-cases each entry.
    """

    selected_symptoms: list  # expected to contain symptom strings


# Module-level set of every symptom selected so far. It is shared by all
# requests served by this process, and the reload endpoint empties it in place.
all_selected_symptoms = set()
| |
|
@app.post("/find_disease")
def find_disease(query: SelectedSymptomsQuery):
    """Return the diseases that have every symptom selected so far.

    The request's symptoms are stripped, lower-cased and added to the
    module-level ``all_selected_symptoms`` set. A disease matches when all of the
    accumulated symptoms appear in its (stripped, lower-cased) symptom list, so an
    empty selection matches every disease.

    Returns:
        A dict with three keys:

        * ``"unique_symptoms_list"``: sorted symptoms of the matched diseases
          that are not selected yet.
        * ``"all_selected_symptoms"``: the accumulated selection.
        * ``"disease_list"``: ``{'Disease', 'Symptoms', 'Treatments'}`` dicts.
    """
    new_symptoms = [symptom.strip().lower() for symptom in query.selected_symptoms]
    # Ignore blanks: a stored '' would prevent every disease from matching.
    all_selected_symptoms.update(symptom for symptom in new_symptoms if symptom)
    # Snapshot the shared set so the checks below use one consistent selection,
    # even if another request modifies the global set meanwhile.
    required = set(all_selected_symptoms)

    disease_list = []
    unique_symptoms_set = set()
    conn = sqlite3.connect("diseases_symptoms.db")
    try:
        rows = conn.execute("SELECT name, symptoms, treatments FROM diseases")
        for disease_name, symptoms_csv, treatments in rows:
            disease_symptoms = [symptom.strip().lower() for symptom in symptoms_csv.split(',')]
            # Every selected symptom must be one of this disease's symptoms.
            if not required.issubset(disease_symptoms):
                continue
            disease_list.append({
                'Disease': disease_name,
                'Symptoms': disease_symptoms,
                'Treatments': treatments
            })
            # Remaining symptoms of matching diseases are the next narrowing choices.
            unique_symptoms_set.update(
                symptom for symptom in disease_symptoms if symptom not in required
            )
    finally:
        conn.close()

    return {
        "unique_symptoms_list": sorted(unique_symptoms_set),
        "all_selected_symptoms": list(required),
        "disease_list": disease_list
    }
| |
|
class DiseaseDetail(BaseModel):
    """JSON request body for the LLM-summary endpoint: one disease record.

    The whole model is interpolated into the prompt text sent to the LLM.
    """

    Disease: str  # disease name
    Symptoms: list  # the disease's symptoms (presumably strings — confirm with the client)
    Treatments: str  # treatment text
    MatchCount: int  # caller-supplied count; not read individually on the server
| |
|
@app.post("/pass2llm")
def pass2llm(query: DiseaseDetail):
    """Ask a Llama 3 model reached through an ngrok tunnel to summarise ``query``.

    The request proceeds in two steps:

    1. The tunnel's public URL is looked up with the ngrok REST API
       (``GET /endpoints``, first endpoint). The ngrok API key is read from the
       ``NGROK_API_KEY`` environment variable.
    2. A non-streaming generate request is POSTed to
       ``{public_url}/api/generate``.

    Returns:
        ``{"message", "llm_response"}`` on success. Otherwise
        ``{"message", "error"}`` describing which step failed; network errors
        are reported this way too, instead of surfacing as HTTP 500s.
    """
    ngrok_failed = "Failed to get public URL from Ngrok!"
    llm_failed = "Failed to get response from LLM!"
    ngrok_timeout = 10  # seconds
    # (connect, read) seconds; a non-streaming generation sends nothing until it
    # has finished, so the read timeout is deliberately generous.
    llm_timeout = (10, 300)

    # SECURITY: an ngrok API key used to be hard-coded here. It has been exposed
    # in source control and must be revoked; provide a new one via the environment.
    api_key = os.environ.get("NGROK_API_KEY")
    if not api_key:
        return {"message": ngrok_failed, "error": "NGROK_API_KEY environment variable is not set"}

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Ngrok-Version": "2"
    }
    try:
        response = requests.get("https://api.ngrok.com/endpoints", headers=headers, timeout=ngrok_timeout)
    except requests.RequestException as exc:
        return {"message": ngrok_failed, "error": str(exc)}
    if response.status_code != 200:
        return {"message": ngrok_failed, "error": response.text}

    endpoints = response.json().get('endpoints') or []
    if not endpoints:
        return {"message": ngrok_failed, "error": "ngrok reported no active endpoints"}
    public_url = endpoints[0]['public_url']

    prompt = f"Here is a list of diseases and their details: {query}. Please generate a summary."
    llm_headers = {"Content-Type": "application/json"}
    llm_payload = {"model": "llama3", "prompt": prompt, "stream": False}
    try:
        llm_response = requests.post(
            f"{public_url}/api/generate", headers=llm_headers, json=llm_payload, timeout=llm_timeout
        )
    except requests.RequestException as exc:
        return {"message": llm_failed, "error": str(exc)}
    if llm_response.status_code != 200:
        return {"message": llm_failed, "error": llm_response.text}

    llm_response_json = llm_response.json()
    return {"message": "Successfully passed to LLM!", "llm_response": llm_response_json.get("response")}
| |
|
| |
|
| |
|
| |
|
@app.post("/trigger-reload")
async def trigger_reload():
    """Forget every symptom accumulated in ``all_selected_symptoms``.

    The module-level set is emptied in place, so no ``global`` declaration is
    required. The response body is the plain string ``"cleared"``.
    """
    all_selected_symptoms.clear()
    return "cleared"
| |
|
| | |
| | |
| | |
| |
|