import os

import streamlit as st
import torch
from groq import APIError, Groq
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
|
| |
# Configure the browser tab title/icon and use the full-width layout before any
# other UI element is rendered.
st.set_page_config(
    layout="wide",
    page_icon="🩺",
    page_title="DermaBot - AI Skin Disease Detector",
)
|
|
| |
MODEL_NAME = "Jayanth2002/dinov2-base-finetuned-SkinDisease"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource(show_spinner="Loading skin-disease model...")
def _load_model_and_processor(model_name, device):
    """Load the image classifier and its preprocessor once per server process.

    Streamlit re-executes this whole script on every widget interaction; without
    ``st.cache_resource`` the model weights would be re-read and moved to
    ``device`` on every button click or text entry.

    Args:
        model_name: Hugging Face Hub id of the fine-tuned classifier.
        device: torch device string the model is moved to ("cuda" or "cpu").

    Returns:
        A ``(model, processor)`` tuple shared by all sessions.
    """
    loaded_model = AutoModelForImageClassification.from_pretrained(model_name).to(device)
    loaded_processor = AutoImageProcessor.from_pretrained(model_name)
    return loaded_model, loaded_processor


# Keep the module-level names the rest of the script uses.
model, processor = _load_model_and_processor(MODEL_NAME, DEVICE)
|
|
| |
# The Groq API key must come from the environment (e.g. a Hugging Face Space
# secret), never from source control.
# NOTE(review): the key that was previously hard-coded on this line was committed
# and should be treated as leaked; revoke it in the Groq console.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error("GROQ_API_KEY is not set. Configure it as an environment variable to enable DermaBot.")
    st.stop()

client = Groq(api_key=GROQ_API_KEY)
|
|
| |
# Seed per-session state on the first run so later reads of these keys never
# fail; existing values survive Streamlit's script re-runs untouched.
for _state_key in ("disease_name", "disease_info"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
|
|
| |
def predict_skin_disease(image):
    """Classify a skin image and return the model's top-1 label.

    Args:
        image: a PIL image; it is converted to RGB before preprocessing.

    Returns:
        The label string from ``model.config.id2label`` for the highest logit.
    """
    rgb_image = image.convert("RGB")
    batch = processor(images=rgb_image, return_tensors="pt").to(DEVICE)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        scores = model(**batch).logits

    top_index = scores.argmax(-1).item()
    return model.config.id2label[top_index]
|
|
| |
def get_disease_info(disease_name):
    """Ask the Groq-hosted LLM for an overview of a skin disease.

    Args:
        disease_name: label produced by ``predict_skin_disease``.

    Returns:
        The text content of the first completion choice.
    """
    # Fixed spelling/grammar of the prompt ("precausions" -> "precautions").
    prompt = (
        f"Provide a detailed explanation about the skin disease '{disease_name}', "
        "including a description of the disease, its causes, precautions, risks "
        "and treatment options."
    )

    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="llama-3.3-70b-versatile",
    )

    return chat_completion.choices[0].message.content
|
|
| |
def chatbot_response(disease_name, user_query):
    """Answer a free-form question about the currently detected disease.

    Args:
        disease_name: label from ``predict_skin_disease``; falsy when nothing
            has been detected yet.
        user_query: the user's question text.

    Returns:
        The LLM's reply, or a short instruction string when the inputs are not
        ready (no detected disease, or an empty question). No API call is made
        in those cases.
    """
    if not disease_name:
        return "Please upload an image and detect the disease first."

    # Don't spend an API call on an empty or whitespace-only question.
    if not user_query or not user_query.strip():
        return "Please type a question about the detected disease first."

    prompt = f"The detected skin disease is '{disease_name}'. {user_query.strip()}"

    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="llama-3.3-70b-versatile",
    )

    return chat_completion.choices[0].message.content
|
|
| |
# Optional header logo. The previously hard-coded URL pointed at a template
# placeholder ("your-huggingface-space") and displayed as a broken image, so the
# logo URL now comes from DERMABOT_LOGO_URL and the logo is skipped when unset.
logo_url = os.environ.get("DERMABOT_LOGO_URL")
if logo_url:
    st.image(logo_url, width=200)

st.title("🩺 DermaBot - AI Skin Disease Detector")
st.write("Upload an image of a skin condition to get a diagnosis and ask questions about it.")
# An automated classifier + LLM output must not be presented as a medical diagnosis.
st.caption("⚠️ For informational purposes only. This is not a substitute for diagnosis by a qualified medical professional.")
|
|
| |
uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

# A diagnosis belongs to one specific upload. When the upload changes (a new
# file is chosen or the current one is removed), drop the previous result so it
# is never displayed next to a different image.
upload_key = (uploaded_image.name, uploaded_image.size) if uploaded_image is not None else None
if st.session_state.get("upload_key") != upload_key:
    st.session_state.upload_key = upload_key
    st.session_state.disease_name = None
    st.session_state.disease_info = None

image = None
if uploaded_image is not None:
    try:
        image = Image.open(uploaded_image)
        # Decode eagerly so corrupt/truncated files fail here instead of during
        # inference. PIL's UnidentifiedImageError is a subclass of OSError.
        image.load()
    except OSError:
        image = None
        st.error("The uploaded file could not be read as an image. Please upload a valid JPG or PNG file.")

if image is not None:
    st.image(image, caption="Uploaded Image", use_container_width=True)

    if st.button("Detect Disease"):
        with st.spinner("Analyzing..."):
            disease_name = predict_skin_disease(image)
            try:
                disease_info = get_disease_info(disease_name)
            except APIError as err:
                # Keep the local prediction even if the LLM call fails.
                st.error(f"Could not fetch disease details from the Groq API: {err}")
                disease_info = "Details are currently unavailable. Please try again later."

        # Persist across Streamlit re-runs (e.g. when the user types a question).
        st.session_state.disease_name = disease_name
        st.session_state.disease_info = disease_info
|
|
| |
# Show the stored diagnosis (if any) on every re-run.
if st.session_state.disease_name:
    st.success(f"**Detected Disease:** {st.session_state.disease_name}")
    st.write(f"**Details:** {st.session_state.disease_info}")


st.subheader("💬 Ask DermaBot about this disease")

user_query = st.text_input("Ask about the detected disease:")

if st.button("Ask"):
    response = None
    with st.spinner("Thinking..."):
        try:
            response = chatbot_response(st.session_state.disease_name, user_query)
        except APIError as err:
            # Surface API/network failures as a UI error instead of a traceback.
            st.error(f"Could not get an answer from the Groq API: {err}")
    if response is not None:
        st.write(response)
|
|
# Footer: horizontal rule followed by the attribution line.
FOOTER_TEXT = "🔍 Powered by **AI & Groq API** | © 2025 DermaBot"
st.markdown("---")
st.write(FOOTER_TEXT)
|
|