| |
| """Untitled2.ipynb |
| |
| Automatically generated by Colab. |
| |
| Original file is located at |
| https://colab.research.google.com/drive/1UcsSFSmZqIdAQTsD_4_CmwwAcAzz0h60 |
| """ |
|
|
import io
import json
import os
import pathlib
import shutil
import tempfile
import uuid
import zipfile
from datetime import datetime, timezone

import folium
import gradio as gr
import huggingface_hub
import matplotlib.cm as cm
import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from huggingface_hub import HfApi, hf_hub_download, create_repo, file_exists, upload_file
from PIL import Image
from scipy.stats import gaussian_kde
|
|
| |
| |
# AutoGluon is optional: without it the app still starts, but classification is disabled.
try:
    from autogluon.multimodal import MultiModalPredictor
except ImportError:
    AUTOGLUON_IMPORTED = False

    class MultiModalPredictor:
        """Stand-in used when AutoGluon is unavailable; any attempt to load a model fails loudly."""

        @staticmethod
        def load(path):
            raise ImportError("AutoGluon MultiModalPredictor is not installed or failed to import.")
else:
    AUTOGLUON_IMPORTED = True


# Where the zipped predictor lives on the Hub and the local folder it is unpacked into.
MODEL_REPO_ID = "ddecosmo/lanternfly_classifier"
ZIP_FILENAME = "autogluon_image_predictor_dir.zip"
MODEL_DIR_NAME = "autogluon_predictor_extracted"
# Order matters: classify_image reports one confidence per label, in this order.
CLASSIFICATION_LABELS = [
    "Lanternfly",
    "Other Insect",
    "Neither",
]

# Populated by the model-loading step further down; PREDICTOR stays None on failure.
PREDICTOR = None
MODEL_STATUS = "Attempting to load model..."
|
|
| |
def _prepare_predictor_dir(repo_id, zip_filename, extract_dir_name) -> str:
    """Download the zipped predictor from the Hub and unpack it into ``./<extract_dir_name>``.

    The archive is first extracted into a fresh, uniquely named staging directory. If the
    archive wraps everything in a single top-level folder (ignoring a macOS ``__MACOSX``
    metadata folder), that folder becomes the predictor directory. Otherwise the staging
    directory itself is renamed into place.

    Args:
        repo_id: Hugging Face Hub model repo holding the zip file.
        zip_filename: Name of the zip file inside the repo.
        extract_dir_name: Name of the target directory, created under the current
            working directory (any existing directory of that name is deleted first).

    Returns:
        Absolute path of the predictor directory, or ``""`` if any step failed (the error
        is printed).
    """
    base_extract_dir = os.path.join(os.getcwd(), extract_dir_name)
    staging_dir = None
    try:
        zip_path = hf_hub_download(repo_id=repo_id, filename=zip_filename)

        # Start from a clean target so files from an earlier run cannot mix in.
        if os.path.exists(base_extract_dir):
            shutil.rmtree(base_extract_dir)

        # Unique staging dir in the same parent as the target (so the final rename stays on
        # one filesystem). A fixed name could be reused with leftovers from an aborted run.
        staging_dir = tempfile.mkdtemp(prefix="temp_ag_extract_", dir=os.getcwd())

        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(staging_dir)

        # macOS-created zips often carry a __MACOSX sibling that would hide a single root folder.
        entries = [name for name in os.listdir(staging_dir) if name != "__MACOSX"]
        if len(entries) == 1 and os.path.isdir(os.path.join(staging_dir, entries[0])):
            shutil.move(os.path.join(staging_dir, entries[0]), base_extract_dir)
        else:
            os.rename(staging_dir, base_extract_dir)
            staging_dir = None  # now lives at base_extract_dir; nothing left to clean up

        return base_extract_dir
    except Exception as e:
        print(f"Error during model prep: {e}")
        return ""
    finally:
        # Remove the staging area on success (after the move) and on failure alike.
        if staging_dir is not None:
            shutil.rmtree(staging_dir, ignore_errors=True)
|
|
| |
# Load the classifier once at import time. On any failure the app keeps running with
# PREDICTOR = None (classification disabled) and MODEL_STATUS explains why.
if AUTOGLUON_IMPORTED:
    try:
        predictor_dir = _prepare_predictor_dir(MODEL_REPO_ID, ZIP_FILENAME, MODEL_DIR_NAME)
        if predictor_dir:
            PREDICTOR = MultiModalPredictor.load(predictor_dir)
            MODEL_STATUS = f"✅ Model Active: {MODEL_REPO_ID}"
        else:
            MODEL_STATUS = "❌ Initialization failed during extraction/download."
    except Exception as e:
        # The UI only shows the exception type; keep the full message in the server log.
        print(f"Error loading model: {e!r}")
        PREDICTOR = None
        MODEL_STATUS = f"❌ Error loading model: {type(e).__name__} (Load Fail)"
else:
    MODEL_STATUS = "❌ AutoGluon not imported. Classification tab is disabled."
|
|
| |
def classify_image(img: Image.Image):
    """Classify *img* with the loaded AutoGluon predictor.

    Args:
        img: PIL image from the Gradio input, or ``None`` if nothing was provided.

    Returns:
        ``(result_text, conf_0, conf_1, conf_2)``. The confidences follow the order of
        ``CLASSIFICATION_LABELS`` and are ``0.0`` whenever no prediction was produced.
    """
    if PREDICTOR is None:
        return "MODEL FAILED TO LOAD", 0.0, 0.0, 0.0
    if img is None:
        return "NO IMAGE PROVIDED", 0.0, 0.0, 0.0

    final_output = [0.0] * len(CLASSIFICATION_LABELS)
    final_result = "PREDICTION FAILED"

    try:
        # The predictor is given a DataFrame of image *file paths*, so round-trip the image
        # through a temporary file. The context manager removes it even if saving (e.g. an
        # unsupported image mode) or prediction raises.
        with tempfile.TemporaryDirectory() as tmp_dir:
            img_path = pathlib.Path(tmp_dir) / "input.png"
            img.save(img_path)
            df_path = pd.DataFrame({"image": [str(img_path)]})
            proba_df = PREDICTOR.predict_proba(df_path, as_pandas=True)

        scores_dict = proba_df.iloc[0].to_dict()
        # Labels missing from the model output default to 0.0.
        scores = [float(scores_dict.get(label, 0.0)) for label in CLASSIFICATION_LABELS]
        predicted_class_label = max(scores_dict, key=scores_dict.get)
        final_output = scores
        final_result = f"Predicted Class: **{predicted_class_label}**"
    except Exception as e:
        final_result = f"CRITICAL PREDICTION FAILURE: {type(e).__name__} - Check AutoGluon dependencies."

    return final_result, final_output[0], final_output[1], final_output[2]
|
|
|
|
| |
# Dataset upload configuration. Saving is enabled only when a token is available and the
# dataset repo could be created or confirmed; otherwise `api` stays None.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HF_TOKEN_SPACE")
DATASET_REPO = os.getenv("DATASET_REPO", "rlogh/lanternfly-data")
METADATA_PATH = "metadata/entries.jsonl"
api = None

if HF_TOKEN and DATASET_REPO:
    api = HfApi(token=HF_TOKEN)
    try:
        # exist_ok=True makes this a no-op when the repo is already there.
        create_repo(DATASET_REPO, repo_type="dataset", exist_ok=True, token=HF_TOKEN)
        GPS_SAVE_STATUS = "✅ Dataset saving enabled."
    except Exception as e:
        GPS_SAVE_STATUS = f"⚠️ Error creating dataset repo: {e}"
        api = None
else:
    GPS_SAVE_STATUS = "⚠️ Running in test mode - no HF credentials (dataset saving disabled)."
|
|
|
|
def get_gps_js():
    """Return the browser-side JavaScript that captures the device's GPS position.

    The returned source is an argument-less arrow function. It is passed as ``js=`` to
    the "Get GPS" button's click event. In the browser it:

    * looks up the textarea inside the element with id ``hidden_gps_input`` and returns
      silently if it is not found;
    * writes ``{error: "..."}`` if ``navigator.geolocation`` is unavailable;
    * otherwise calls ``getCurrentPosition`` (high-accuracy mode, 10 s timeout) and writes
      either ``{latitude, longitude, accuracy, timestamp}`` or ``{error: err.message}``.

    Each write is a JSON string followed by a dispatched ``input`` event. Per the inline
    JS comment, the event is meant to trigger the textbox's change event, whose Python
    handler parses the JSON.

    Returns:
        str: JavaScript source code.
    """
    # NOTE: the text below is a runtime string sent to the browser (its "//" lines are JS
    # comments, not Python comments) -- edits change the emitted code.
    return """
    () => {
        // Look for the hidden textbox element by its ID
        const textarea = document.querySelector('#hidden_gps_input textarea');
        if (!textarea) return;

        if (!navigator.geolocation) {
            textarea.value = JSON.stringify({error: "Geolocation not supported by this browser/device."});
            textarea.dispatchEvent(new Event('input', { bubbles: true }));
            return;
        }
        // Request current position
        navigator.geolocation.getCurrentPosition(
            function(position) {
                const data = {
                    latitude: position.coords.latitude,
                    longitude: position.coords.longitude,
                    accuracy: position.coords.accuracy,
                    timestamp: position.timestamp
                };
                // Write JSON string to the hidden textbox and trigger a change event
                textarea.value = JSON.stringify(data);
                textarea.dispatchEvent(new Event('input', { bubbles: true }));
            },
            function(err) {
                textarea.value = JSON.stringify({ error: err.message });
                textarea.dispatchEvent(new Event('input', { bubbles: true }));
            },
            { enableHighAccuracy: true, timeout: 10000 }
        );
    }
    """
|
|
def handle_gps_location(json_str):
    """Parse the browser's GPS JSON and fan it out to the status and form textboxes.

    Args:
        json_str: JSON text written by the ``get_gps_js`` snippet. It is either an object
            with ``latitude``/``longitude``/``accuracy``/``timestamp`` (milliseconds since
            the Unix epoch) or ``{"error": message}``.

    Returns:
        ``(status_markdown, latitude, longitude, accuracy, device_timestamp)``. All but the
        status are strings, and all four are empty when no position could be extracted.
        ``device_timestamp`` is an ISO-8601 UTC string with an explicit ``+00:00`` offset,
        or ``""`` when the payload carries no usable timestamp.
    """
    empty = ("", "", "", "")
    try:
        data = json.loads(json_str)
        if not isinstance(data, dict):
            raise ValueError("expected a JSON object")
        if "error" in data:
            return (f"❌ **GPS Error**: {data['error']}", *empty)

        raw_lat, raw_lon = data.get("latitude"), data.get("longitude")
        if raw_lat is None or raw_lon is None:
            # Avoid reporting "captured" with blank or "None" coordinates.
            return ("❌ **GPS Error**: no coordinates in the location response.", *empty)

        lat, lon = str(raw_lat), str(raw_lon)
        raw_accuracy = data.get("accuracy")
        accuracy = "" if raw_accuracy is None else str(raw_accuracy)

        timestamp_ms = data.get("timestamp")
        device_ts = ""
        # bool is a subclass of int, so exclude it; "> 0" also rules out NaN.
        if (isinstance(timestamp_ms, (int, float)) and not isinstance(timestamp_ms, bool)
                and timestamp_ms > 0):
            # Interpret as UTC explicitly instead of the server's local timezone.
            device_ts = datetime.fromtimestamp(timestamp_ms / 1000, tz=timezone.utc).isoformat()

        status_msg = f"✅ **GPS Captured**: {lat[:8]}, {lon[:8]} (accuracy: {accuracy}m)"
        return status_msg, lat, lon, accuracy, device_ts
    except Exception as e:
        return (f"❌ **Error parsing GPS data**: {e}", *empty)
|
|
|
|
def _save_image_to_repo(pil_img: Image.Image, dest_rel_path: str) -> None:
    """Encode *pil_img* as JPEG in memory and upload it to ``DATASET_REPO`` at *dest_rel_path*.

    JPEG has no alpha channel or palette, so PIL refuses to encode modes such as RGBA or
    P (common for PNG uploads). Any mode other than RGB/L is therefore converted to RGB
    first; for RGBA this simply drops the alpha channel.

    Raises:
        Whatever ``PIL.Image.save`` or ``huggingface_hub.upload_file`` raises; the caller
        reports it.
    """
    if pil_img.mode not in ("RGB", "L"):
        pil_img = pil_img.convert("RGB")
    img_bytes = io.BytesIO()
    pil_img.save(img_bytes, format="JPEG", quality=90)
    img_bytes.seek(0)
    upload_file(
        path_or_fileobj=img_bytes, path_in_repo=dest_rel_path,
        repo_id=DATASET_REPO, repo_type="dataset", token=HF_TOKEN,
        commit_message=f"Upload image {dest_rel_path}",
    )
|
|
def _append_jsonl_in_repo(new_row: dict) -> None:
    """Append *new_row* as one JSON line to ``METADATA_PATH`` in the dataset repo.

    There is no server-side append here: the current file is downloaded, extended and
    re-uploaded whole. Errors while checking for or reading the existing file are
    propagated, not ignored. Treating an unreadable file as empty would re-upload only
    the new row and silently erase every earlier entry.

    NOTE(review): two concurrent saves can still race (download-modify-upload with no
    locking, so the later upload wins).

    Raises:
        Errors from ``file_exists``, ``hf_hub_download`` or ``upload_file``.
    """
    existing_lines = []
    if file_exists(DATASET_REPO, METADATA_PATH, repo_type="dataset", token=HF_TOKEN):
        local_path = hf_hub_download(
            repo_id=DATASET_REPO, filename=METADATA_PATH,
            repo_type="dataset", token=HF_TOKEN,
        )
        with open(local_path, "r", encoding="utf-8") as f:
            # Drop blank lines so repeated rewrites do not accumulate them.
            existing_lines = [line for line in f.read().splitlines() if line.strip()]

    existing_lines.append(json.dumps(new_row, ensure_ascii=False))
    # JSON Lines convention: every record, including the last, ends with a newline.
    buf = io.BytesIO(("\n".join(existing_lines) + "\n").encode("utf-8"))

    # Timezone-aware UTC so the trailing "Z" is accurate.
    commit_ts = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
    upload_file(
        path_or_fileobj=buf, path_in_repo=METADATA_PATH,
        repo_id=DATASET_REPO, repo_type="dataset", token=HF_TOKEN,
        commit_message=f"Append 1 entry at {commit_ts}",
    )
|
|
|
|
def save_to_dataset(image, lat, lon, accuracy_m, device_ts):
    """Validate a capture and store its image and metadata row in the Hugging Face dataset.

    Without an ``api`` client (no credentials), the inputs are only validated and a
    preview row is returned ("test mode"); nothing is uploaded.

    Args:
        image: PIL image or NumPy array from the image input; ``None`` if missing.
        lat, lon: Latitude/longitude as text from the GPS textboxes.
        accuracy_m: Optional reported accuracy in metres, as text.
        device_ts: Optional device timestamp string.

    Returns:
        ``(status_markdown, preview)``. ``preview`` is the row as a JSON string, or
        ``None`` when nothing was saved. ``None`` is used instead of ``""`` because the
        preview is displayed by a ``gr.JSON`` component, which parses string values as
        JSON; an empty string is not valid JSON.
    """
    try:
        if image is None:
            return "❌ **Error**: No image captured.", None
        if not lat or not lon:
            return "❌ **Error**: GPS coordinates missing.", None
        try:
            lat_val, lon_val = float(lat), float(lon)
        except (TypeError, ValueError):
            return "❌ **Error**: GPS coordinates must be numbers.", None
        # Also rejects NaN, since every comparison with NaN is False.
        if not (-90.0 <= lat_val <= 90.0 and -180.0 <= lon_val <= 180.0):
            return "❌ **Error**: GPS coordinates out of range.", None

        # Webcam/array inputs arrive as NumPy arrays; uploads need a PIL image.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype("uint8"))

        if not api:
            img_id = str(uuid.uuid4())
            timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
            row = {"id": img_id, "image": f"test_{timestamp_str}_{img_id[:8]}.jpg",
                   "latitude": lat_val, "longitude": lon_val,
                   "mode": "test"}
            status = f"📝 **Test Mode**: Data validated successfully! Sample {img_id[:8]}"
            return status, json.dumps(row, indent=2)

        sample_id = str(uuid.uuid4())
        timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
        image_rel_path = f"images/lanternfly_{timestamp_str}_{sample_id[:8]}.jpg"

        # Upload the image first so the metadata row never points at a missing file.
        _save_image_to_repo(image, image_rel_path)
        # Timezone-aware UTC so the "Z" suffix is accurate.
        server_ts_utc = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")

        row = {
            "id": sample_id, "image": image_rel_path,
            "latitude": lat_val, "longitude": lon_val,
            "accuracy_m": float(accuracy_m) if accuracy_m else None,
            "device_timestamp": device_ts if device_ts else None,
            "server_timestamp_utc": server_ts_utc,
        }
        _append_jsonl_in_repo(row)

        status = f"✅ **Success!** Saved to dataset! Image: `{image_rel_path}`"
        return status, json.dumps(row, indent=2)
    except Exception as e:
        return f"❌ **Error during save**: {str(e)}", None
|
|
| |
# The map tab reads from the same dataset repo the capture tab writes to, so a
# DATASET_REPO environment override applies to both. The metadata file path is the
# METADATA_PATH defined alongside DATASET_REPO.
HUGGINGFACE_DATA_REPO = DATASET_REPO

# Bounding box (degrees) of the Pittsburgh area; sightings outside it are not mapped.
pittsburgh_lat_min, pittsburgh_lat_max = 40.3, 40.6
pittsburgh_lon_min, pittsburgh_lon_max = -80.2, -79.8
|
|
|
|
def load_lanternfly_data_from_hf():
    """Download the sightings metadata (JSONL) from the Hub and return in-area coordinates.

    Each line is expected to be a JSON object with numeric ``latitude`` and ``longitude``
    fields. Lines that are blank, malformed, not objects, non-numeric, or outside the
    Pittsburgh bounding box are skipped rather than aborting the whole load.

    Returns:
        ``(latitudes, longitudes, None)`` as NumPy arrays on success, or
        ``(None, None, error_message)`` if the download fails or no usable point remains.
    """
    try:
        local_path = hf_hub_download(
            repo_id=HUGGINGFACE_DATA_REPO,
            filename=METADATA_PATH,
            repo_type="dataset",
            # Same credentials the writer uses (needed if the repo is private). None falls
            # back to huggingface_hub's default authentication.
            token=HF_TOKEN,
        )

        latitudes = []
        longitudes = []
        with open(local_path, "r", encoding="utf-8") as f:
            for line in f:
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    continue
                # A valid JSON line need not be an object (e.g. a bare list); .get would raise.
                if not isinstance(record, dict):
                    continue
                lat = record.get("latitude")
                lon = record.get("longitude")
                if not (isinstance(lat, (float, int)) and isinstance(lon, (float, int))):
                    continue
                if (pittsburgh_lat_min <= lat <= pittsburgh_lat_max
                        and pittsburgh_lon_min <= lon <= pittsburgh_lon_max):
                    latitudes.append(lat)
                    longitudes.append(lon)

        if not latitudes:
            return None, None, "Error: Found no valid coordinates in the dataset."
        return np.array(latitudes), np.array(longitudes), None
    except Exception as e:
        return None, None, f"Error downloading or parsing HF data: {type(e).__name__} - {e}"
|
|
|
|
def calculate_kde_and_points():
    """Load sighting coordinates and fit a 2-D Gaussian KDE over them.

    Returns:
        ``(latitudes, longitudes, kde, None)`` on success, or
        ``(None, None, None, error_message)`` if loading or fitting failed.
    """
    lats, lons, load_error = load_lanternfly_data_from_hf()
    if load_error:
        return None, None, None, load_error

    try:
        # gaussian_kde takes a (n_dims, n_points) array; here row 0 is longitude, row 1 latitude.
        density_model = gaussian_kde(np.vstack([lons, lats]))
    except Exception as exc:
        return None, None, None, f"Error calculating KDE: {type(exc).__name__} - {exc}"
    return lats, lons, density_model, None
|
|
|
|
def plot_kde_and_points(min_lat, max_lat, min_lon, max_lon, original_latitudes, original_longitudes, kde_object):
    """Build an interactive Folium map of sightings, each marker colored by KDE density.

    Args:
        min_lat, max_lat, min_lon, max_lon: Study-area bounds. Currently unused (the map
            is centred on the mean of the points); kept for interface compatibility.
        original_latitudes, original_longitudes: Equal-length sequences of coordinates.
        kde_object: Fitted density callable, evaluated here on stacked
            (longitude, latitude) rows.

    Returns:
        ``(None, html)``. The first slot is a placeholder for a static image that is not
        produced; ``html`` is the Folium map rendered as an embeddable HTML snippet.
    """
    original_coordinates = np.vstack([original_longitudes, original_latitudes])
    density_at_original_points = kde_object(original_coordinates)
    d_min = density_at_original_points.min()
    d_max = density_at_original_points.max()
    # Min-max scale into [0, 1] for the colormap; epsilon avoids 0/0 when all densities match.
    density_normalized = (density_at_original_points - d_min) / (d_max - d_min + 1e-9)

    # matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9; the
    # colormap registry (available since 3.5) is the supported lookup.
    colormap = matplotlib.colormaps["viridis"]
    map_center_lat = np.mean(original_latitudes)
    map_center_lon = np.mean(original_longitudes)
    m_colored_points = folium.Map(location=[map_center_lat, map_center_lon], zoom_start=12)

    for lat, lon, density_norm in zip(original_latitudes, original_longitudes, density_normalized):
        color = matplotlib.colors.rgb2hex(colormap(density_norm))
        folium.CircleMarker(
            location=[lat, lon], radius=5, color=color, fill=True, fill_color=color, fill_opacity=0.7,
            tooltip=f"Lat: {lat:.5f}, Lon: {lon:.5f}",
        ).add_to(m_colored_points)

    colored_points_map_html = m_colored_points._repr_html_()
    return None, colored_points_map_html
|
|
|
|
def update_visualization_live():
    """Refresh callback for the map tab (button click and initial page load).

    Returns:
        ``(static_image, map_html, error_text)``. On failure the error is shown as a
        heading in place of the map and mirrored into the debug textbox.
    """
    lats, lons, kde, problem = calculate_kde_and_points()
    if not problem:
        bounds = (pittsburgh_lat_min, pittsburgh_lat_max, pittsburgh_lon_min, pittsburgh_lon_max)
        static_image, map_html = plot_kde_and_points(*bounds, lats, lons, kde)
        return static_image, map_html, ""
    return None, f"<h1>{problem}</h1>", f"Error: {problem}"
|
|
| |
|
|
with gr.Blocks(title="Lanternfly Tracking Tool") as app:
    gr.Markdown("# Lanternfly Tracking Tool")

    # ---- Tab 1: classify an image and save it together with GPS metadata ----
    with gr.Tab("1. Field Capture & Classification"):
        gr.Markdown("## 📸 Lanternfly Classification and GPS Data Capture")
        gr.Markdown(f"**Model Status**: {MODEL_STATUS}")
        gr.Markdown(f"**GPS Save Status**: {GPS_SAVE_STATUS}")

        with gr.Row():
            # Left column: image input and classifier output.
            with gr.Column(scale=1):
                image_in = gr.Image(
                    type="pil", label="1. Upload or Capture Image",
                    value="https://placehold.co/224x224/ff6347/ffffff?text=Lanternfly",
                    sources=["upload", "webcam"],
                )
                # Disabled when the model failed to load (no click handler is wired then either).
                run_classify_btn = gr.Button(
                    "🔍 Run Classification", variant="primary", interactive=PREDICTOR is not None
                )

                gr.Markdown("### Classification Result")
                final_result_box = gr.Textbox(label="Prediction Result", interactive=False)
                with gr.Row():
                    conf_0 = gr.Number(label=f"Confidence: {CLASSIFICATION_LABELS[0]}", interactive=False)
                    conf_1 = gr.Number(label=f"Confidence: {CLASSIFICATION_LABELS[1]}", interactive=False)
                    conf_2 = gr.Number(label=f"Confidence: {CLASSIFICATION_LABELS[2]}", interactive=False)

            # Right column: GPS capture and dataset upload.
            with gr.Column(scale=1):
                gr.Markdown("## 📍 GPS Data Capture")
                gps_btn = gr.Button("📍 Get GPS", variant="primary")
                # Filled by the injected JS (get_gps_js); its change event parses the JSON.
                hidden_gps_input = gr.Textbox(visible=False, elem_id="hidden_gps_input")

                with gr.Row():
                    lat_box = gr.Textbox(label="Latitude", interactive=True)
                    lon_box = gr.Textbox(label="Longitude", interactive=True)
                with gr.Row():
                    accuracy_box = gr.Textbox(label="Accuracy (m)", interactive=True)
                    device_ts_box = gr.Textbox(label="Device Timestamp", interactive=True)

                # Always enabled: without credentials save_to_dataset runs its validation-only
                # test mode instead of uploading, which would be unreachable if this were disabled.
                save_btn = gr.Button("💾 Save Image & GPS to Dataset", variant="secondary")

                gr.Markdown("### Save Status & Preview")
                gps_status = gr.Markdown("📍 **Ready for GPS capture and saving.**")
                preview = gr.JSON(label="Preview JSON")

        if PREDICTOR is not None:
            run_classify_btn.click(
                fn=classify_image,
                inputs=[image_in],
                outputs=[final_result_box, conf_0, conf_1, conf_2],
            )

        # fn=None: the click only runs the browser-side JS, which writes into hidden_gps_input.
        gps_btn.click(fn=None, inputs=[], outputs=[], js=get_gps_js())
        hidden_gps_input.change(
            fn=handle_gps_location,
            inputs=[hidden_gps_input],
            outputs=[gps_status, lat_box, lon_box, accuracy_box, device_ts_box],
        )
        save_btn.click(
            fn=save_to_dataset,
            inputs=[image_in, lat_box, lon_box, accuracy_box, device_ts_box],
            outputs=[gps_status, preview],
        )

    # ---- Tab 2: density-colored map of recorded sightings ----
    with gr.Tab("2. Spatial Data Visualization (KDE)"):
        gr.Markdown("## 🗺️ Kernel Density Estimation of Lanternfly Sightings")
        gr.Markdown(
            f"**Data Source**: {HUGGINGFACE_DATA_REPO} - Automatically loaded from `{METADATA_PATH}`"
        )

        refresh_btn = gr.Button("🔄 Refresh Map from Hugging Face Data", variant="primary")
        kde_error_box = gr.Textbox(label="Error/Debug Message", visible=False)

        with gr.Row():
            interactive_map_out = gr.HTML(label="Interactive Points Map Colored by KDE (Folium)")

        # Sink for the static-image slot returned by update_visualization_live (always None).
        matplotlib_placeholder = gr.State(value=None)

        refresh_btn.click(
            fn=update_visualization_live,
            inputs=[],
            outputs=[matplotlib_placeholder, interactive_map_out, kde_error_box],
        )

    # Draw the map once when a browser session first loads the app.
    app.load(
        fn=update_visualization_live,
        inputs=[],
        outputs=[matplotlib_placeholder, interactive_map_out, kde_error_box],
    )


if __name__ == "__main__":
    app.launch()