import base64
import mimetypes
import re
import shutil
import time
from pathlib import Path
from typing import AsyncGenerator, List, Optional, Tuple

import gradio as gr
from gradio import ChatMessage
| |
|
| |
|
class ChatInterface:
    """
    A chat interface for interacting with a medical AI agent through Gradio.

    Handles file uploads, message processing, and chat history management.
    Supports both regular image files and DICOM medical imaging files.

    All conversation state (thread id, current file paths) is stored on the
    instance, so it is shared by every caller that uses the same instance.
    """

    def __init__(self, agent, tools_dict):
        """
        Initialize the chat interface.

        Args:
            agent: The medical AI agent to handle requests. It must expose a
                ``workflow`` object whose ``stream(inputs, config)`` yields events.
            tools_dict (dict): Dictionary of available tools for image processing,
                keyed by tool name. ``"DicomProcessorTool"`` is used for DICOM uploads.
        """
        self.agent = agent
        self.tools_dict = tools_dict
        # Uploaded files are copied here (relative to the current working directory).
        self.upload_dir = Path("temp")
        self.upload_dir.mkdir(exist_ok=True)
        # Created lazily on the first message and passed as the workflow's thread_id.
        self.current_thread_id = None

        # Last uploaded file as saved in upload_dir (may be a raw DICOM file).
        self.original_file_path = None
        # Image currently shown in the UI (tools such as image_visualizer may replace it).
        self.display_file_path = None
        # Viewable image produced at upload time (the converted image for DICOM).
        # This is what gets encoded for the model, since DICOM bytes are not an image format.
        self.upload_preview_path = None

    def handle_upload(self, file_path: Optional[str]) -> Optional[str]:
        """
        Copy a newly uploaded file into ``upload_dir`` and update the current paths.

        DICOM files (``.dcm``) are converted to a displayable image with the
        ``DicomProcessorTool``; other files are displayed as-is.

        Args:
            file_path (Optional[str]): Path to the uploaded file.

        Returns:
            Optional[str]: Display path for the UI, or None if no file was uploaded.
        """
        if not file_path:
            return None

        source = Path(file_path)
        suffix = source.suffix.lower()

        # Nanosecond stamp plus an existence check, so two uploads made in quick
        # succession never overwrite each other (older chat entries keep their file).
        stamp = time.time_ns()
        saved_path = self.upload_dir / f"upload_{stamp}{suffix}"
        while saved_path.exists():
            stamp += 1
            saved_path = self.upload_dir / f"upload_{stamp}{suffix}"
        shutil.copy2(source, saved_path)

        # Resolve the display image first, so a failed DICOM conversion leaves the
        # previously uploaded file's state untouched.
        if suffix == ".dcm":
            output, _ = self.tools_dict["DicomProcessorTool"]._run(str(saved_path))
            display_path = output["image_path"]
        else:
            display_path = str(saved_path)

        self.original_file_path = str(saved_path)
        self.display_file_path = display_path
        self.upload_preview_path = display_path
        return self.display_file_path

    def add_message(
        self, message: str, display_image: str, history: List[dict]
    ) -> Tuple[List[dict], gr.Textbox]:
        """
        Add the user's turn (current image and text) to the chat history.

        Args:
            message (str): Text message to add.
            display_image (str): Path to the image being displayed.
            history (List[dict]): Current chat history (mutated in place).

        Returns:
            Tuple[List[dict], gr.Textbox]: Updated history, and a textbox that keeps
            the message but is locked until the caller re-enables it.
        """
        image_path = self.original_file_path or display_image
        if image_path is not None:
            history.append({"role": "user", "content": {"path": image_path}})
        if message is not None:
            history.append({"role": "user", "content": message})
        return history, gr.Textbox(value=message, interactive=False)

    def _model_image_path(self, image_path: str) -> str:
        """Return the file whose bytes should be shown to the model for ``image_path``.

        A raw DICOM upload is swapped for the viewable image converted at upload time.
        Any other path is returned unchanged.
        """
        if (
            Path(image_path).suffix.lower() == ".dcm"
            and image_path == self.original_file_path
            and self.upload_preview_path
        ):
            return self.upload_preview_path
        return image_path

    @staticmethod
    def _encode_image_data_url(path: str) -> str:
        """Read ``path`` and return it as a base64 ``data:`` URL with a guessed image MIME type.

        Falls back to ``image/jpeg`` when the extension does not map to an image type.
        """
        mime_type, _ = mimetypes.guess_type(path)
        if not mime_type or not mime_type.startswith("image/"):
            mime_type = "image/jpeg"
        with open(path, "rb") as img_file:
            encoded = base64.b64encode(img_file.read()).decode("utf-8")
        return f"data:{mime_type};base64,{encoded}"

    def _strip_upload_paths(self, text: str) -> str:
        """Remove references to files under ``upload_dir`` from agent-visible text."""
        pattern = re.escape(self.upload_dir.as_posix()) + r"/[^\s]*"
        return re.sub(pattern, "", text)

    def _build_agent_messages(
        self, message: Optional[str], image_path: Optional[str]
    ) -> List[dict]:
        """
        Build the user-turn messages sent to the agent workflow.

        Args:
            message (Optional[str]): User text, or None for none.
            image_path (Optional[str]): Image to attach, or None for none.

        Returns:
            List[dict]: Two messages for the image (its file path, for tools, and
            its inline content, for the model), followed by the text message if any.
        """
        messages: List[dict] = []
        if image_path is not None:
            # The textual path lets tools locate the file on disk.
            messages.append({"role": "user", "content": f"image_path: {image_path}"})
            image_url = self._encode_image_data_url(self._model_image_path(image_path))
            messages.append(
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": image_url},
                        }
                    ],
                }
            )
        if message is not None:
            messages.append({"role": "user", "content": [{"type": "text", "text": message}]})
        return messages

    async def process_message(
        self, message: str, display_image: Optional[str], chat_history: List[ChatMessage]
    ) -> AsyncGenerator[Tuple[List[ChatMessage], Optional[str], str], None]:
        """
        Send the user's turn to the agent and stream its responses into the chat.

        Args:
            message (str): User message to process.
            display_image (Optional[str]): Path to the currently displayed image.
            chat_history (List[ChatMessage]): Current chat history.

        Yields:
            Tuple[List[ChatMessage], Optional[str], str]: Updated chat history, the
            current display path, and an empty string (clears the input textbox).
        """
        chat_history = chat_history or []

        if not self.current_thread_id:
            self.current_thread_id = str(time.time())

        image_path = self.original_file_path or display_image

        try:
            # Built inside the try, so a missing or unreadable image is reported
            # in the chat instead of escaping the generator.
            messages = self._build_agent_messages(message, image_path)

            # NOTE(review): workflow.stream is iterated with a plain for-loop, so it
            # blocks the event loop between events even though this method is async.
            for event in self.agent.workflow.stream(
                {"messages": messages}, {"configurable": {"thread_id": self.current_thread_id}}
            ):
                if not isinstance(event, dict):
                    continue

                if "process" in event:
                    content = event["process"]["messages"][-1].content
                    if content:
                        content = self._strip_upload_paths(content)
                        chat_history.append(ChatMessage(role="assistant", content=content))
                        yield chat_history, self.display_file_path, ""

                elif "execute" in event:
                    for tool_message in event["execute"]["messages"]:
                        tool_name = tool_message.name
                        # SECURITY(review): eval() executes arbitrary Python taken from the
                        # tool message content. If tools only emit Python literals, switch
                        # to ast.literal_eval (or have tools return JSON) — confirm format.
                        tool_result = eval(tool_message.content)[0]

                        if tool_result:
                            metadata = {"title": f"🖼️ Image from tool: {tool_name}"}
                            formatted_result = " ".join(
                                line.strip() for line in str(tool_result).splitlines()
                            ).strip()
                            metadata["description"] = formatted_result
                            chat_history.append(
                                ChatMessage(
                                    role="assistant",
                                    content=formatted_result,
                                    metadata=metadata,
                                )
                            )

                        # The visualizer's output image replaces the displayed image.
                        # Guarded so an empty or unexpected result does not raise.
                        if (
                            tool_name == "image_visualizer"
                            and isinstance(tool_result, dict)
                            and "image_path" in tool_result
                        ):
                            self.display_file_path = tool_result["image_path"]
                            chat_history.append(
                                ChatMessage(
                                    role="assistant",
                                    content={"path": self.display_file_path},
                                )
                            )

                        yield chat_history, self.display_file_path, ""

        except Exception as e:
            chat_history.append(
                ChatMessage(
                    role="assistant", content=f"❌ Error: {str(e)}", metadata={"title": "Error"}
                )
            )
            yield chat_history, self.display_file_path, ""
|
| |
|
def create_demo(agent, tools_dict):
    """
    Create a Gradio demo interface for the medical AI agent.

    Args:
        agent: The medical AI agent to handle requests
        tools_dict (dict): Dictionary of available tools for image processing

    Returns:
        gr.Blocks: Gradio Blocks interface
    """
    # A single ChatInterface backs the whole app, so its state (thread id and
    # uploaded file paths) is shared by every session using this Blocks object.
    interface = ChatInterface(agent, tools_dict)

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        with gr.Column():
            gr.Markdown(
                """
            # DAEASE
            Medical Image Reasoning Agent.
            """
            )

            with gr.Row():
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(
                        [],
                        height=800,
                        container=True,
                        show_label=True,
                        elem_classes="chat-box",
                        type="messages",
                        label="Agent",
                        avatar_images=(
                            None,
                            "assets/medrax_logo.jpg",
                        ),
                    )
                    with gr.Row():
                        with gr.Column(scale=3):
                            txt = gr.Textbox(
                                show_label=False,
                                placeholder="Ask about the X-ray...",
                                container=False,
                            )

                with gr.Column(scale=3):
                    image_display = gr.Image(
                        label="Image", type="filepath", height=700, container=True
                    )
                    with gr.Row():
                        upload_button = gr.UploadButton(
                            "📎 Upload X-Ray",
                            file_types=["image"],
                        )
                        # NOTE(review): accepts any file, but handle_upload only treats a
                        # ".dcm" suffix as DICOM; other files are displayed as images.
                        dicom_upload = gr.UploadButton(
                            "📄 Upload DICOM",
                            file_types=["file"],
                        )
                    with gr.Row():
                        clear_btn = gr.Button("Clear Chat")
                        new_thread_btn = gr.Button("New Thread")

        def clear_chat():
            """Reset the chat, the uploaded image, and the agent thread."""
            interface.original_file_path = None
            interface.display_file_path = None
            # Drop the thread as well; otherwise the agent keeps answering from a
            # conversation (and image paths) that is no longer visible. A new thread
            # is created lazily on the next message.
            interface.current_thread_id = None
            return [], None

        def new_thread():
            """Start a fresh agent thread while keeping the current image."""
            interface.current_thread_id = str(time.time())
            return [], interface.display_file_path

        def handle_file_upload(file):
            """Forward an uploaded file's path to the interface.

            The upload may arrive as a path string or as an object exposing the
            path via ``.name``; both are accepted, as is None.
            """
            file_path = getattr(file, "name", file)
            return interface.handle_upload(file_path)

        chat_msg = txt.submit(
            interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
        )
        bot_msg = chat_msg.then(
            interface.process_message,
            inputs=[txt, image_display, chatbot],
            outputs=[chatbot, image_display, txt],
        )
        # Unlock the textbox that add_message disabled while the agent was running.
        bot_msg.then(lambda: gr.Textbox(interactive=True), None, [txt])

        upload_button.upload(handle_file_upload, inputs=upload_button, outputs=image_display)

        dicom_upload.upload(handle_file_upload, inputs=dicom_upload, outputs=image_display)

        clear_btn.click(clear_chat, outputs=[chatbot, image_display])
        new_thread_btn.click(new_thread, outputs=[chatbot, image_display])

    return demo
| |
|