Spaces:
Runtime error
Runtime error
| import base64 | |
| import io | |
| import logging | |
| import cv2 | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from PIL import Image | |
# Module-level logger used by the helpers in this file; it is named after the
# module so its records can be filtered or configured per module.
logger = logging.getLogger(name=__name__)
def plot_image_prediction(image, predictions, title=None, figsize=(10, 8)):
    """
    Plot an image next to a horizontal bar chart of its top predictions.

    Args:
        image (PIL.Image or str): Image object, or a filesystem path to one.
        predictions (list): List of ``(label, probability)`` tuples. At most the
            five highest-probability entries are shown, the most likely on top.
        title (str, optional): Overall figure title.
        figsize (tuple): Figure size in inches.

    Returns:
        matplotlib.figure.Figure: The figure object. If plotting fails, the
        error is logged and a small figure showing the error message is
        returned instead.
    """
    fig = None
    try:
        # Load the image if a path is given; copy() reads the pixels so the
        # file handle can be closed right away.
        if isinstance(image, str):
            with Image.open(image) as opened:
                img = opened.copy()
        else:
            img = image

        fig, (ax_img, ax_pred) = plt.subplots(1, 2, figsize=figsize)

        # Left panel: the image itself.
        ax_img.imshow(img)
        ax_img.set_title("X-ray Image")
        ax_img.axis("off")

        # Right panel: top-5 predictions as horizontal bars.
        if predictions:
            top = sorted(predictions, key=lambda pred: pred[1], reverse=True)[:5]
            labels = [pred[0] for pred in top]
            probs = [pred[1] for pred in top]

            y_pos = np.arange(len(top))
            ax_pred.barh(y_pos, probs, align="center")
            ax_pred.set_yticks(y_pos)
            ax_pred.set_yticklabels(labels)
            # barh puts index 0 at the bottom; invert so the most likely
            # prediction is listed first (at the top).
            ax_pred.invert_yaxis()
            ax_pred.set_xlabel("Probability")
            ax_pred.set_title("Top Predictions")
            ax_pred.set_xlim(0, 1)
            for i, prob in enumerate(probs):
                ax_pred.text(prob + 0.02, i, f"{prob:.1%}", va="center")
        else:
            # Show a message rather than an empty set of axes.
            ax_pred.text(0.5, 0.5, "No predictions available", ha="center", va="center")
            ax_pred.axis("off")

        if title:
            fig.suptitle(title, fontsize=16)
        fig.tight_layout()
        return fig
    except Exception as e:
        logger.exception("Error plotting image prediction: %s", e)
        # pyplot keeps every created figure alive until closed; drop the
        # half-built one before creating the error figure.
        if fig is not None:
            plt.close(fig)
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Error: {str(e)}", ha="center", va="center")
        ax.axis("off")
        return fig
def create_heatmap_overlay(image, heatmap, alpha=0.4):
    """
    Create a heatmap overlay on an X-ray image to highlight areas of interest.

    The heatmap is resized to the image size. Negative values are clipped to 0,
    and the result is scaled by its maximum into [0, 1]. It is then coloured
    with OpenCV's JET colormap and alpha-blended onto the image.

    Args:
        image (PIL.Image, str or numpy.ndarray): PIL image (any mode), path to
            an image file, or an array. Arrays are treated as OpenCV-style
            BGR/BGRA or single-channel grayscale. They presumably must be
            uint8 for ``cv2.addWeighted`` to accept them -- TODO confirm for
            other dtypes.
        heatmap (numpy.ndarray): 2-D activation map. Singleton axes such as
            ``(1, H, W)`` or ``(H, W, 1)`` are squeezed away.
        alpha (float): Weight of the heatmap in the blend. The image gets
            ``1 - alpha``.

    Returns:
        PIL.Image: RGB image with the heatmap overlay. If anything fails, the
        error is logged and the input is returned unmodified as a PIL image.
        For a path that cannot be opened at all, the error raised by
        ``Image.open`` propagates.
    """
    try:
        if isinstance(image, str):
            img = cv2.imread(image)
            if img is None:
                raise ValueError(f"Could not load image: {image}")
        elif isinstance(image, Image.Image):
            # Grayscale ("L"), palette ("P"), etc. images do not give a
            # 3-channel array, and COLOR_RGB2BGR would raise on them.
            # Normalise the mode first.
            img = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
        else:
            img = np.asarray(image)

        # Bring the image to 3-channel BGR so it matches the colour-mapped heatmap.
        if img.ndim == 2 or img.shape[2] == 1:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)

        heat = np.asarray(heatmap, dtype=np.float32)
        if heat.ndim > 2:
            heat = np.squeeze(heat)
        # cv2.resize takes (width, height).
        heat = cv2.resize(heat, (img.shape[1], img.shape[0]))

        # Normalise to [0, 1]. An all-zero (or all-negative) map stays at 0
        # instead of becoming 0/0 = NaN.
        heat = np.maximum(heat, 0)
        peak = float(heat.max())
        if peak > 0:
            heat = np.minimum(heat / peak, 1)

        colored = cv2.applyColorMap(np.uint8(255 * heat), cv2.COLORMAP_JET)
        overlay = cv2.addWeighted(img, 1 - alpha, colored, alpha, 0)
        return Image.fromarray(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
    except Exception as e:
        logger.exception("Error creating heatmap overlay: %s", e)
        # Fall back to the original input, converted to a PIL image.
        if isinstance(image, str):
            return Image.open(image)
        if isinstance(image, Image.Image):
            return image
        arr = np.asarray(image)
        if arr.ndim == 3 and arr.shape[2] == 3:
            arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
        elif arr.ndim == 3 and arr.shape[2] == 4:
            arr = cv2.cvtColor(arr, cv2.COLOR_BGRA2RGBA)
        return Image.fromarray(arr)
def plot_report_entities(text, entities, figsize=(12, 8)):
    """
    Visualize entities extracted from a medical report.

    The figure lists the entities per category, each category in its own
    colour. Below that it shows the report text, with every whole-word
    occurrence of an entity wrapped in square brackets.

    Args:
        text (str): Report text.
        entities (dict): Mapping of category name (e.g. "problem", "test",
            "treatment", "anatomy") to a list of entity strings.
        figsize (tuple): Figure size in inches.

    Returns:
        matplotlib.figure.Figure: The figure object. If plotting fails, the
        error is logged and a small figure showing the error message is
        returned instead.
    """
    # Colours per entity category; categories not listed here are drawn black.
    category_colors = {
        "problem": "#e74c3c",  # Red
        "test": "#3498db",  # Blue
        "treatment": "#2ecc71",  # Green
        "anatomy": "#9b59b6",  # Purple
    }
    fig = None
    try:
        entities = entities or {}
        fig, ax = plt.subplots(figsize=figsize)
        ax.axis("off")
        fig.patch.set_facecolor("#f8f9fa")
        ax.set_facecolor("#f8f9fa")

        ax.text(
            0.5,
            0.98,
            "Medical Report Analysis",
            ha="center",
            va="top",
            fontsize=18,
            fontweight="bold",
            color="#2c3e50",
        )

        # Entities grouped by category.
        y_pos = 0.9
        ax.text(
            0.05,
            y_pos,
            "Extracted Entities:",
            fontsize=14,
            fontweight="bold",
            color="#2c3e50",
        )
        y_pos -= 0.05
        for category, items in entities.items():
            if not items:
                continue
            y_pos -= 0.05
            ax.text(
                0.1,
                y_pos,
                f"{str(category).capitalize()}:",
                fontsize=12,
                fontweight="bold",
            )
            y_pos -= 0.05
            ax.text(
                0.15,
                y_pos,
                _escape_mathtext(", ".join(str(item) for item in items)),
                wrap=True,
                fontsize=11,
                color=category_colors.get(category, "black"),
            )

        # Report text with the entities marked.
        y_pos -= 0.1
        ax.text(
            0.05,
            y_pos,
            "Report Text (extracted entities shown in [brackets]):",
            fontsize=14,
            fontweight="bold",
            color="#2c3e50",
        )
        y_pos -= 0.05
        # A single matplotlib Text artist cannot colour individual words, and
        # LaTeX markup such as \textcolor would be drawn literally unless
        # text.usetex is enabled. So entities are bracketed instead; their
        # colours appear in the list above.
        marked = _mark_entities(text or "", entities)
        ax.text(0.05, y_pos, _escape_mathtext(marked), va="top", fontsize=10, wrap=True)

        fig.tight_layout(rect=[0, 0.03, 1, 0.97])
        return fig
    except Exception as e:
        logger.exception("Error plotting report entities: %s", e)
        # pyplot keeps created figures alive until closed.
        if fig is not None:
            plt.close(fig)
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Error: {str(e)}", ha="center", va="center")
        ax.axis("off")
        return fig


def _escape_mathtext(s):
    """Escape "$" so matplotlib draws it literally instead of entering mathtext mode."""
    return s.replace("$", r"\$")


def _mark_entities(text, entities):
    """
    Return *text* with each whole-word entity occurrence wrapped in [brackets].

    All entities are matched in one regex pass, longest first. This means an
    entity contained in a longer one at the same position is not marked
    twice, and inserted brackets are never matched again.

    Args:
        text (str): Text to mark up.
        entities (dict): Mapping of category to an iterable of entity strings.

    Returns:
        str: The marked-up text, or *text* unchanged if there are no
        non-blank entities.
    """
    import re

    terms = {
        str(item)
        for items in entities.values()
        if items
        for item in items
        if str(item).strip()
    }
    if not terms:
        return text
    alternatives = "|".join(
        re.escape(term) for term in sorted(terms, key=lambda t: (-len(t), t))
    )
    # Lookarounds rather than \b so that entities starting or ending with
    # punctuation, e.g. "(R)", can still match.
    pattern = re.compile(rf"(?<!\w)(?:{alternatives})(?!\w)")
    return pattern.sub(lambda m: f"[{m.group(0)}]", text)
def plot_multimodal_results(
    fused_results, image=None, report_text=None, figsize=(12, 10)
):
    """
    Visualize the results of multimodal analysis on a 2x2 grid.

    Panels:
        - top left: overview (primary finding, severity, agreement score)
        - top right: key findings
        - bottom left: the image
        - bottom right: follow-up recommendations

    A disclaimer is drawn along the bottom edge.

    Args:
        fused_results (dict): Results from multimodal fusion. Keys read:
            - ``"severity"``: dict with ``"level"`` and ``"score"``; the
              score is shown out of 4.
            - ``"primary_finding"``
            - ``"agreement_score"``: presumably a fraction in [0, 1]; it is
              formatted as a percentage.
            - ``"findings"`` and ``"followup_recommendations"``: lists of
              strings; at most 5 of each are drawn.
            Missing or None values fall back to defaults.
        image (PIL.Image or str, optional): Image or path to image.
        report_text (str, optional): Accepted for API compatibility; not used
            by this function.
        figsize (tuple): Figure size in inches.

    Returns:
        matplotlib.figure.Figure: The figure object. If plotting fails, the
        error is logged and a small figure showing the error message is
        returned instead.
    """
    severity_colors = {
        "Normal": "#2ecc71",  # Green
        "Mild": "#3498db",  # Blue
        "Moderate": "#f39c12",  # Orange
        "Severe": "#e74c3c",  # Red
        "Critical": "#c0392b",  # Dark Red
    }
    fig = None
    try:
        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(2, 2)
        fig.suptitle(
            "Multimodal Medical Analysis Results",
            fontsize=18,
            fontweight="bold",
            y=0.98,
        )

        # 1. Overview panel (top left)
        ax_overview = fig.add_subplot(gs[0, 0])
        ax_overview.axis("off")

        severity = fused_results.get("severity") or {}
        severity_level = severity.get("level", "Unknown")
        severity_score = severity.get("score", 0)
        primary_finding = fused_results.get("primary_finding", "Unknown")
        # Coerce so that a None or non-numeric score cannot break the
        # percentage formatting and replace the whole figure with an error.
        try:
            agreement = float(fused_results.get("agreement_score") or 0)
        except (TypeError, ValueError):
            agreement = 0.0

        y_pos = 0.9
        ax_overview.text(
            0.5,
            y_pos,
            "ANALYSIS OVERVIEW",
            fontsize=14,
            fontweight="bold",
            ha="center",
            va="center",
        )
        y_pos -= 0.15
        ax_overview.text(
            0.1,
            y_pos,
            f"Primary Finding: {primary_finding}",
            fontsize=12,
            ha="left",
            va="center",
        )
        y_pos -= 0.1

        # Severity level, coloured by level name.
        ax_overview.text(
            0.1, y_pos, "Severity Level:", fontsize=12, ha="left", va="center"
        )
        ax_overview.text(
            0.4,
            y_pos,
            severity_level,
            fontsize=12,
            color=severity_colors.get(severity_level, "black"),
            fontweight="bold",
            ha="left",
            va="center",
        )
        ax_overview.text(
            0.6, y_pos, f"({severity_score}/4)", fontsize=10, ha="left", va="center"
        )
        y_pos -= 0.1

        # Agreement score: green above 0.7, orange above 0.4, red otherwise.
        if agreement > 0.7:
            agreement_color = "#2ecc71"
        elif agreement > 0.4:
            agreement_color = "#f39c12"
        else:
            agreement_color = "#e74c3c"
        ax_overview.text(
            0.1, y_pos, "Agreement Score:", fontsize=12, ha="left", va="center"
        )
        ax_overview.text(
            0.4,
            y_pos,
            f"{agreement:.0%}",
            fontsize=12,
            color=agreement_color,
            fontweight="bold",
            ha="left",
            va="center",
        )

        # 2. Findings panel (top right)
        _draw_bullet_panel(
            fig.add_subplot(gs[0, 1]),
            "KEY FINDINGS",
            fused_results.get("findings"),
            "No specific findings detailed.",
        )

        # 3. Image panel (bottom left)
        ax_image = fig.add_subplot(gs[1, 0])
        if image is not None:
            if isinstance(image, str):
                # copy() reads the pixels so the file handle can be closed.
                with Image.open(image) as opened:
                    img = opened.copy()
            else:
                img = image
            ax_image.imshow(img)
            ax_image.set_title("X-ray Image", fontsize=12)
        else:
            ax_image.text(0.5, 0.5, "No image available", ha="center", va="center")
        ax_image.axis("off")

        # 4. Recommendation panel (bottom right)
        _draw_bullet_panel(
            fig.add_subplot(gs[1, 1]),
            "RECOMMENDATIONS",
            fused_results.get("followup_recommendations"),
            "No specific recommendations provided.",
        )

        fig.text(
            0.5,
            0.03,
            "DISCLAIMER: This analysis is for informational purposes only and should not replace professional medical advice.",
            fontsize=9,
            style="italic",
            ha="center",
        )
        fig.tight_layout(rect=[0, 0.05, 1, 0.95])
        return fig
    except Exception as e:
        logger.exception("Error plotting multimodal results: %s", e)
        # pyplot keeps created figures alive until closed.
        if fig is not None:
            plt.close(fig)
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Error: {str(e)}", ha="center", va="center")
        ax.axis("off")
        return fig


def _draw_bullet_panel(ax, heading, items, empty_message, max_items=5):
    """
    Draw a centred heading and a bullet list of up to *max_items* entries on *ax*.

    Items start at y=0.8 and step down by 0.15. Capping the list at 5 keeps the
    last bullet at y=0.2, inside the panel.

    Args:
        ax (matplotlib.axes.Axes): Axes to draw on; their axis lines are hidden.
        heading (str): Panel heading.
        items (list or str or None): Entries to list. A single string is
            treated as one entry.
        empty_message (str): Text shown when there are no entries.
        max_items (int): Maximum number of entries drawn.
    """
    ax.axis("off")
    y_pos = 0.9
    ax.text(
        0.5, y_pos, heading, fontsize=14, fontweight="bold", ha="center", va="center"
    )
    y_pos -= 0.1
    if isinstance(items, str):
        items = [items]
    entries = list(items or [])[:max_items]
    if not entries:
        ax.text(0.1, y_pos, empty_message, fontsize=11, ha="left", va="center")
        return
    for entry in entries:
        ax.text(0.05, y_pos, "•", fontsize=14, ha="left", va="center")
        ax.text(0.1, y_pos, str(entry), fontsize=11, ha="left", va="center", wrap=True)
        y_pos -= 0.15
def figure_to_base64(fig):
    """
    Render a figure to PNG and return the PNG bytes as a base64 text string.

    Args:
        fig (matplotlib.figure.Figure): Figure to render. It is saved with
            ``format="png"`` and ``bbox_inches="tight"``.

    Returns:
        str: Base64-encoded PNG data decoded as UTF-8 text, or an empty string
        if rendering or encoding fails (the error is logged).
    """
    try:
        with io.BytesIO() as stream:
            fig.savefig(stream, format="png", bbox_inches="tight")
            png_bytes = stream.getvalue()
        return base64.b64encode(png_bytes).decode("utf-8")
    except Exception as e:
        logger.error(f"Error converting figure to base64: {e}")
        return ""