| | import plotly.express as px |
| | import plotly.graph_objects as go |
| | import pandas as pd |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | from matplotlib.figure import Figure |
| | from matplotlib.backends.backend_agg import FigureCanvasAgg |
| | from io import BytesIO |
| | from PIL import Image |
| | import json |
| | import os |
| | import cartopy.crs as ccrs |
| | import cartopy.feature as cfeature |
| | from matplotlib.figure import Figure |
| | from matplotlib.backends.backend_agg import FigureCanvasAgg |
| |
|
def get_background_map_trace():
    """Build a uniformly grey Plotly choropleth of world country outlines.

    Reads ``countries.geo.json`` from the directory that contains this module
    and shades every feature that carries an ``id`` with the same light grey.
    The result is meant to be used as a backdrop layer beneath other traces.

    Returns:
        A ``go.Choropleth`` trace, or ``None`` when the GeoJSON file is
        missing, contains no features with an ``id``, or cannot be
        loaded/converted. A diagnostic message is printed in each case.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    source_path = os.path.join(module_dir, 'countries.geo.json')

    if not os.path.exists(source_path):
        # Print enough context to diagnose a missing deployment asset.
        print(f"Warning: Local GeoJSON not found at {source_path}")
        print(f"Current Working Directory: {os.getcwd()}")
        try:
            print(f"Files in {module_dir}: {os.listdir(module_dir)}")
        except Exception as err:
            print(f"Error listing files: {err}")
        return None

    try:
        with open(source_path, 'r', encoding='utf-8') as handle:
            geojson = json.load(handle)

        country_ids = [feature['id'] for feature in geojson['features'] if 'id' in feature]
        print(f"DEBUG: Loaded {len(country_ids)} countries from {source_path}")

        if not country_ids:
            print("DEBUG: No IDs found in GeoJSON features")
            return None

        # A constant z plus a single-colour scale paints every country the same grey.
        flat_grey = 'rgb(243, 243, 243)'
        return go.Choropleth(
            geojson=geojson,
            locations=country_ids,
            z=[1] * len(country_ids),
            colorscale=[[0, flat_grey], [1, flat_grey]],
            showscale=False,
            marker_line_color='rgb(204, 204, 204)',
            marker_line_width=0.5,
            hoverinfo='skip',
            name='Background',
        )
    except Exception as err:
        # Best-effort: a broken asset should not take the caller down.
        print(f"Error loading GeoJSON: {err}")
        return None
| |
|
| |
|
def plot_global_map_static(df, lat_col='centre_lat', lon_col='centre_lon'):
    """Render every sample location as a small dot on a static world map.

    Args:
        df: DataFrame with one row per sample, or ``None``.
        lat_col: Name of the latitude column (coerced to numeric).
        lon_col: Name of the longitude column (coerced to numeric).

    Returns:
        ``(image, plotted_rows)``. ``image`` is a PIL image of the PNG render.
        ``plotted_rows`` holds the rows actually drawn: the rows with
        parseable coordinates, thinned to every second row when there are
        more than 250000 of them. Returns ``(None, None)`` when ``df`` is
        ``None``.
    """
    if df is None:
        return None, None

    # Coerce coordinates to numbers; rows whose lat/lon cannot be parsed are dropped.
    coords = df.copy()
    for column in (lat_col, lon_col):
        coords[column] = pd.to_numeric(coords[column], errors='coerce')
    coords = coords.dropna(subset=[lat_col, lon_col])

    total = len(coords)
    if total <= 250000:
        plotted = coords
    else:
        # Thin very large inputs by keeping every second row.
        stride = 2
        plotted = coords.iloc[::stride]
        print(f"Sampled {len(plotted)} points from {total} total points (step={stride}) for visualization.")

    plate = ccrs.PlateCarree()
    fig = Figure(figsize=(10, 5), dpi=300)
    ax = fig.add_subplot(111, projection=plate)

    ax.add_feature(cfeature.LAND, facecolor='lightgray', alpha=0.2)
    ax.add_feature(cfeature.COASTLINE, linewidth=0.5, alpha=0.5)

    ax.scatter(
        plotted[lon_col],
        plotted[lat_col],
        s=0.2,
        c="blue",
        marker='o',
        edgecolors='none',
        transform=plate,
        label='Samples',
    )

    # Always show the whole globe, regardless of where the points fall.
    ax.set_extent([-180, 180, -90, 90], crs=plate)
    ax.axis('off')
    ax.legend(loc='lower left', markerscale=5, frameon=True, facecolor='white', framealpha=0.9)
    fig.tight_layout()

    png = BytesIO()
    fig.savefig(png, format='png', facecolor='white')
    png.seek(0)
    return Image.open(png), plotted
| |
|
def plot_geographic_distribution(df, scores, threshold, lat_col='centre_lat', lon_col='centre_lon', title="Search Results"):
    """Map the highest-scoring fraction of samples, coloured by similarity score.

    Args:
        df: DataFrame with one row per sample, or ``None``.
        scores: Per-row similarity scores, stored as a new ``score`` column.
            Presumably aligned with ``df``'s rows (pandas assignment rules
            apply); ``None`` disables the plot.
        threshold: Fraction of rows to keep after sorting by descending
            score. At least one row is always kept. The legend shows it in
            per mille (``0.001`` -> "Top 1‰").
        lat_col: Name of the latitude column.
        lon_col: Name of the longitude column.
        title: Title drawn above the map.

    Returns:
        ``(image, df_filtered)``: a PIL image of the PNG render and the kept
        rows (sorted by descending score), or ``(None, None)`` when ``df`` or
        ``scores`` is ``None``.
    """
    if df is None or scores is None:
        return None, None

    ranked = df.copy()
    ranked['score'] = scores
    ranked = ranked.sort_values(by='score', ascending=False)

    # Keep the requested fraction, but never fewer than one row.
    top_n = max(1, int(len(ranked) * threshold))
    df_filtered = ranked.head(top_n)

    fig = Figure(figsize=(10, 5), dpi=300)
    ax = fig.add_subplot(111, projection=ccrs.PlateCarree())

    ax.add_feature(cfeature.LAND, facecolor='lightgray', alpha=0.2)
    ax.add_feature(cfeature.COASTLINE, linewidth=0.5, alpha=0.5)

    # ':g' keeps sub-per-mille thresholds readable (0.0001 -> "0.1‰") instead of
    # rounding them to "0‰"; whole per-mille values render as before (0.001 -> "1‰").
    label_text = f'Top {threshold * 1000:g}‰ Matches'
    sc = ax.scatter(
        df_filtered[lon_col],
        df_filtered[lat_col],
        c=df_filtered['score'],
        cmap='Reds',
        s=0.35,
        alpha=0.8,
        transform=ccrs.PlateCarree(),
        label=label_text,
    )

    # Always show the whole globe, regardless of where the matches fall.
    ax.set_extent([-180, 180, -90, 90], crs=ccrs.PlateCarree())
    ax.axis('off')
    # Render the caller-supplied title (the parameter was previously ignored).
    ax.set_title(title)

    cbar = fig.colorbar(sc, ax=ax, fraction=0.025, pad=0.02)
    cbar.set_label('Similarity Score')

    ax.legend(loc='lower left', markerscale=3, frameon=True, facecolor='white', framealpha=0.9)
    fig.tight_layout()

    buf = BytesIO()
    fig.savefig(buf, format='png', facecolor='white')
    buf.seek(0)
    img = Image.open(buf)

    return img, df_filtered
| |
|
| |
|
def format_results_for_gallery(results):
    """Turn search results into ``(image, caption)`` pairs for a Gradio Gallery.

    Args:
        results: Iterable of result dicts. Entries whose ``image_384`` is
            missing or ``None`` are skipped. Every other entry must provide
            ``score``, ``lat``, ``lon`` and ``id``.

    Returns:
        A list of ``(image_384, caption)`` tuples in input order.
    """
    def caption_for(entry):
        return (
            f"Score: {entry['score']:.4f}\n"
            f"Lat: {entry['lat']:.2f}, Lon: {entry['lon']:.2f}\n"
            f"ID: {entry['id']}"
        )

    return [
        (entry.get('image_384'), caption_for(entry))
        for entry in results
        if entry.get('image_384') is not None
    ]
| |
|
| |
|
def _show_or_placeholder(ax, image):
    """Draw ``image`` on ``ax`` (or an "N/A" placeholder if it is None), hide the axes, and return whether an image was drawn."""
    # Compare against None explicitly: truth-testing a NumPy array raises ValueError.
    drawn = image is not None
    if drawn:
        ax.imshow(image)
    else:
        ax.text(0.5, 0.5, "N/A", ha='center', va='center')
    ax.axis('off')
    return drawn


def _rank_caption(rank, res, tag=""):
    """Build the "Rank N<tag> / Score / (lat, lon)" title for one result dict."""
    return f"Rank {rank}{tag}\nScore: {res['score']:.4f}\n({res['lat']:.2f}, {res['lon']:.2f})"


def _figure_to_png_image(fig):
    """Lay out ``fig``, render it to an in-memory PNG and return it as a PIL image."""
    fig.tight_layout()
    buf = BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight')
    buf.seek(0)
    return Image.open(buf)


def plot_top5_overview(query_image, results, query_info="Query"):
    """
    Generates a matplotlib figure showing the query image and top retrieved images.
    Similar to the visualization in SigLIP_embdding.ipynb.
    Uses OO Matplotlib API for thread safety.

    Layouts:
      * ``query_image`` is None and there are exactly 10 results: a 2x5 grid
        of ``image_384`` thumbnails.
      * otherwise: one column per result, plus a leading query column when
        ``query_image`` is given. The top row shows each result's
        ``image_384`` and the bottom row its ``image_full``.

    Args:
        query_image: Image shown in the first column, or ``None``. Anything
            ``Axes.imshow`` accepts works, including NumPy arrays.
        results: Sequence of result dicts. ``image_384``/``image_full`` may be
            missing or ``None`` (drawn as "N/A"). ``score``, ``lat`` and ``lon``
            are read only for tiles whose thumbnail is present.
        query_info: Text shown under the "Query" heading.

    Returns:
        A PIL image of the rendered PNG, or ``None`` when ``results`` is empty.
    """
    top_k = len(results)
    if top_k == 0:
        return None

    # Without a query image, exactly 10 results are laid out as a 2x5 grid.
    if query_image is None and top_k == 10:
        rows, cols = 2, 5
        fig = Figure(figsize=(4 * cols, 4 * rows))
        FigureCanvasAgg(fig)  # attach an Agg canvas explicitly; no pyplot state involved

        for i, res in enumerate(results):
            ax = fig.add_subplot(rows, cols, i + 1)
            if _show_or_placeholder(ax, res.get('image_384')):
                ax.set_title(_rank_caption(i + 1, res), fontsize=9)

        return _figure_to_png_image(fig)

    has_query = query_image is not None
    cols = top_k + (1 if has_query else 0)
    rows = 2

    fig = Figure(figsize=(4 * cols, 8))
    FigureCanvasAgg(fig)  # attach an Agg canvas explicitly; no pyplot state involved

    if has_query:
        ax = fig.add_subplot(rows, cols, 1)
        ax.imshow(query_image)
        ax.set_title(f"Query\n{query_info}", color='blue', fontweight='bold')
        ax.axis('off')

        # Occupy the cell below the query with an empty axes so the grid stays aligned.
        fig.add_subplot(rows, cols, cols + 1).axis('off')
        start_col = 2
    else:
        start_col = 1

    for i, res in enumerate(results):
        # Top row: 384px thumbnail.
        ax_thumb = fig.add_subplot(rows, cols, start_col + i)
        if _show_or_placeholder(ax_thumb, res.get('image_384')):
            ax_thumb.set_title(_rank_caption(i + 1, res, " (384)"), fontsize=9)

        # Bottom row: full-resolution original.
        ax_full = fig.add_subplot(rows, cols, cols + start_col + i)
        if _show_or_placeholder(ax_full, res.get('image_full')):
            ax_full.set_title("Original", fontsize=9)

    return _figure_to_png_image(fig)
| |
|
def plot_location_distribution(df_all, query_lat, query_lon, results, query_info="Query"):
    """
    Generates a global distribution map for location search.
    Reference: improve2_satclip.ipynb

    Draws the query coordinate as a red star, each retrieved result as a
    blue cross, and a faint dashed line from the query to every result.

    Args:
        df_all: Only checked against ``None`` here (``None`` disables the
            plot); its contents are not otherwise used.
        query_lat: Latitude of the query coordinate.
        query_lon: Longitude of the query coordinate.
        results: Sequence of result dicts with ``lat`` and ``lon`` keys.
        query_info: Text appended to the title in parentheses.

    Returns:
        A PIL image of the rendered PNG, or ``None`` when ``df_all`` is ``None``.
    """
    if df_all is None:
        return None

    n_results = len(results)

    fig = Figure(figsize=(8, 4), dpi=300)
    FigureCanvasAgg(fig)  # attach an Agg canvas explicitly; no pyplot state involved
    ax = fig.add_subplot(111, projection=ccrs.PlateCarree())

    ax.add_feature(cfeature.LAND, facecolor='lightgray', alpha=0.2)
    ax.add_feature(cfeature.COASTLINE, linewidth=0.5, alpha=0.5)

    # High zorder keeps the query marker above the result markers and connector lines.
    ax.scatter(query_lon, query_lat, c='red', s=150, marker='*', edgecolors='black', zorder=10, label='Input Coordinate')

    res_lons = [r['lon'] for r in results]
    res_lats = [r['lat'] for r in results]
    ax.scatter(res_lons, res_lats, c='blue', s=50, marker='x', linewidths=2, label=f'Retrieved Top-{n_results}')

    for lon, lat in zip(res_lons, res_lats):
        ax.plot([query_lon, lon], [query_lat, lat], 'b--', alpha=0.2)

    ax.legend(loc='upper right')
    # Use the actual result count so the title agrees with the legend (was hard-coded to 5).
    ax.set_title(f"Location of Top {n_results} Matched Images ({query_info})")
    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")
    ax.grid(True, alpha=0.2)

    buf = BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight')
    buf.seek(0)

    return Image.open(buf)
| |
|