| import gradio as gr |
| import tensorflow as tf |
| import numpy as np |
| from PIL import Image |
| import google.generativeai as genai |
| import os |
| import markdown2 |
|
|
| |
# Class names; predict_image indexes this list with the argmax of the model's
# output, so the order must match the model's output positions.
labels = ["cataract", "diabetic_retinopathy", "glaucoma", "normal"]
# TensorFlow SavedModel loaded from the "model" directory (path is relative to
# the current working directory).
model = tf.saved_model.load("model")
|
|
| |
# Configure the Gemini client from the GOOGLE_API_KEY environment variable.
# NOTE(review): if the variable is unset, None is passed as the key and
# presumably every later generate_content call fails — confirm.
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
|
|
| |
def get_disease_detail(disease):
    """Ask Gemini for a short write-up about *disease* and return it as HTML.

    Args:
        disease: A class name from ``labels`` (e.g. ``"glaucoma"``). The value
            ``"normal"`` requests a congratulatory note with eye-care tips
            instead of a description of a condition.

    Returns:
        HTML produced by ``markdown2`` from the model's reply. If the reply is
        empty, the placeholder "No response." is rendered instead. If the
        request fails, the HTML-escaped text ``"Error: <exception>"`` is
        returned.
    """
    import html  # local import so the file's top-level imports stay untouched

    if disease == "normal":
        prompt = "Create a text congratulating on healthy eyes with tips to keep them healthy."
    else:
        # Labels are snake_case ("diabetic_retinopathy"); give the LLM readable text.
        name = disease.replace("_", " ")
        prompt = (
            f"Diagnosis: {name}\n\n"
            f"What is {name}?\nCauses and suggestions to prevent {name}."
        )

    try:
        response = genai.GenerativeModel("gemini-1.5-flash").generate_content(prompt)
        # The `response.text` accessor can itself raise (e.g. when the reply
        # has no text part), so it stays inside the try.
        text = response.text.strip() if response and response.text else ""
    except Exception as e:
        # The result is rendered by a gr.HTML component; escape it so '<...>'
        # fragments in exception messages are shown rather than parsed as tags.
        return html.escape(f"Error: {e}")

    return markdown2.markdown(text or "No response.")
|
|
| |
def predict_image(image):
    """Classify a fundus image and fetch an explanation for the top class.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without providing an image.

    Returns:
        A ``(scores, explanation)`` pair. ``scores`` is ``{top_label: score}``
        for the ``gr.Label`` output, or ``None`` when no image was given.
        ``explanation`` is the HTML from ``get_disease_detail``, or a prompt
        to upload an image.
    """
    if image is None:
        # Gradio passes None when the input is empty; avoid an AttributeError.
        return None, "Please upload an eye fundus image."

    # Normalise to 3 channels so RGBA PNGs, grayscale or palette images yield a
    # (224, 224, 3) array. NOTE(review): assumes the model was trained on RGB
    # input scaled to [0, 1] — confirm against the training pipeline.
    rgb = image.convert("RGB").resize((224, 224))
    batch = np.expand_dims(np.asarray(rgb, dtype=np.float32) / 255.0, axis=0)

    infer = model.signatures["serving_default"]
    outputs = infer(tf.convert_to_tensor(batch, dtype=tf.float32))
    # Single-image batch: flatten so argmax/indexing address the class axis.
    scores = outputs["output_0"].numpy().ravel()

    best = int(np.argmax(scores))
    top_label = labels[best]
    explanation = get_disease_detail(top_label)

    # Plain float keeps the gr.Label payload free of numpy scalar types.
    return {top_label: float(scores[best])}, explanation
|
|
| |
# Sample fundus images offered as one-click inputs. Each inner list holds the
# value for the interface's single image input.
example_images = [
    ["exp_eye_images/" + filename]
    for filename in (
        "0_right_h.png",
        "03fd50da928d_dr.png",
        "108_right_h.png",
        "1062_right_c.png",
        "1084_right_c.png",
        "image_1002_g.jpg",
    )
]
|
|
| |
# Style for the explanation panel: fixed height with a vertical scrollbar.
_SCROLLABLE_CSS = (
    ".scrollable-html {height: 206px; overflow-y: auto; border: 1px solid #ccc; "
    "padding: 10px; box-sizing: border-box;}"
)

# Output components: the top predicted class, then the HTML explanation.
_prediction_output = gr.Label(num_top_classes=1, label="Prediction")
_explanation_output = gr.HTML(label="Explanation", elem_classes=["scrollable-html"])

# Web UI wiring: one PIL image in, routed through predict_image, two outputs.
interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=[_prediction_output, _explanation_output],
    examples=example_images,
    title="DR Predictor",
    description="Upload an eye fundus image, and the model predicts the condition.",
    allow_flagging="never",
    css=_SCROLLABLE_CSS,
)
|
|
if __name__ == "__main__":
    # Start the server only when run as a script, so importing this module
    # does not launch it. share=True asks Gradio for a public share link.
    interface.launch(share=True)
|
|