| | import os |
| | import shutil |
| | import tempfile |
| | import gradio as gr |
| | import plotly.graph_objects as go |
| |
|
| | import pandas as pd |
| | from time import time |
| | from utils import ( |
| | create_file_structure, |
| | init_info_csv, |
| | add_to_info_csv, |
| | ) |
| |
|
| | from satseg.dataset import create_datasets, create_inference_dataset |
| | from satseg.model import train_model, save_model, run_inference, load_model |
| | from satseg.seg_result import combine_seg_maps, get_combined_map_contours |
| | from satseg.geo_tools import ( |
| | shapefile_to_latlong, |
| | shapefile_to_grid_indices, |
| | points_to_shapefile, |
| | contours_to_shapefile, |
| | get_tif_n_channels, |
| | ) |
| |
|
# Root of all on-disk state used by the app.
DATA_DIR = "data"
MODEL_DIR = os.path.join(DATA_DIR, "models")  # saved model weights (*.pt)
TIF_DIR = os.path.join(DATA_DIR, "tifs")  # uploaded input images
MASK_DIR = os.path.join(DATA_DIR, "masks")  # uploaded / generated mask files
INFO_DIR = os.path.join(DATA_DIR, "info")  # CSV catalogues of the above

# CSV catalogues listing what is stored in the directories above.
MODEL_INFO_PATH = os.path.join(INFO_DIR, "model_data.csv")
DATASET_TIF_INFO_PATH = os.path.join(INFO_DIR, "dataset_tif_data.csv")
DATASET_MASK_INFO_PATH = os.path.join(INFO_DIR, "dataset_mask_data.csv")

# Create the directory tree and catalogue files at import time. MODEL_DIR is
# listed explicitly because gr_train_model saves model weights into it; it is
# placed right after DATA_DIR so its parent already exists.
create_file_structure(
    [DATA_DIR, MODEL_DIR, TIF_DIR, MASK_DIR, INFO_DIR],
    [MODEL_INFO_PATH, DATASET_TIF_INFO_PATH, DATASET_MASK_INFO_PATH],
)
# Column headers for each catalogue (init_info_csv is a utils helper; presumably
# it only writes the header when the file is new — confirm in utils).
init_info_csv(
    MODEL_INFO_PATH,
    [
        "Name",
        "Architecture",
        "# of channels",
        "Train TIF",
        "Train Mask",
        "Expression",
        "Path",
    ],
)
init_info_csv(DATASET_TIF_INFO_PATH, ["Name", "# of channels", "Path"])
init_info_csv(DATASET_MASK_INFO_PATH, ["Name", "Class", "Path"])
| |
|
| |
|
def gr_train_model(
    tif_names, mask_names, model_name, expression, progress=gr.Progress()
):
    """Train a UNet on the selected TIF/mask files and register it in the model catalogue.

    Args:
        tif_names: TIF file names (relative to TIF_DIR) selected in the UI.
        mask_names: Mask file names (relative to MASK_DIR) selected in the UI.
        model_name: Free-text model name. Whitespace runs become "_", any
            directory part is dropped, and ".pt" is appended to form the file
            name, which is also the catalogue "Name".
        expression: Remote-sensing index expression as whitespace-separated
            tokens; may be empty.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        A status string, the refreshed model catalogue DataFrame, and a
        dropdown update listing every catalogued model name.
    """
    tif_paths = [os.path.join(TIF_DIR, name) for name in tif_names]
    mask_paths = [os.path.join(MASK_DIR, name) for name in mask_names]
    # create_datasets receives the expression as a token list; an empty or
    # missing textbox value yields [].
    tokens = (expression or "").strip().split()

    progress(0, desc="Creating Dataset...")
    with tempfile.TemporaryDirectory() as tempdir:
        train_set, val_set = create_datasets(
            tif_paths, mask_paths, tempdir, expression=tokens
        )
        progress(0.05, desc="Training Model...")
        model, _ = train_model(train_set, val_set, "unet")
        # Read while the dataset's temp directory still exists.
        n_channels = val_set.n_channels

    progress(0.95, desc="Model Trained! Saving...")
    # basename() keeps a user-typed name such as "../x" from escaping MODEL_DIR.
    model_file = os.path.basename("_".join(model_name.split())) + ".pt"
    model_path = os.path.join(MODEL_DIR, model_file)
    # Ensure the target directory exists before writing the weights.
    os.makedirs(MODEL_DIR, exist_ok=True)
    save_model(model, model_path)
    add_to_info_csv(
        MODEL_INFO_PATH,
        [
            model_file,
            "UNet",
            n_channels,
            ";".join(tif_names),
            ";".join(mask_names),
            " ".join(tokens),
            model_path,
        ],
    )
    progress(1.0, desc="Done!")
    model_df = pd.read_csv(MODEL_INFO_PATH)

    return "Done!", model_df, gr.Dropdown.update(choices=model_df["Name"].to_list())
| |
|
| |
|
def gr_run_inference(tif_names, model_name, progress=gr.Progress()):
    """Segment the selected TIFs with a catalogued model and export shapefiles.

    Args:
        tif_names: TIF file names (relative to TIF_DIR) selected in the UI.
        model_name: Catalogue "Name" of the model to use.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        The paths returned by contours_to_shapefile, one per entry of the
        combined segmentation result. They are written under
        MASK_DIR/output, which is wiped at the start of every run.

    Raises:
        KeyError: If model_name is not present in the model catalogue.
    """
    start = time()
    tif_paths = [os.path.join(TIF_DIR, name) for name in tif_names]

    model_df = pd.read_csv(MODEL_INFO_PATH)
    rows = model_df[model_df["Name"] == model_name]
    if rows.empty:
        raise KeyError(f"Unknown model: {model_name!r}")
    # Retraining under the same name overwrites the same .pt file and presumably
    # appends another catalogue row, so use the most recent row.
    row = rows.iloc[-1]
    model_path = row["Path"]
    raw_expression = row["Expression"]
    # An empty expression is stored as an empty CSV field, which pandas reads
    # back as NaN (a float) and which would crash on .split().
    expression = "" if pd.isna(raw_expression) else str(raw_expression)

    file_paths = []
    with tempfile.TemporaryDirectory() as tempdir:
        progress(0, desc="Creating Dataset...")
        # 256 is passed positionally (presumably the tile size — confirm in satseg).
        dataset = create_inference_dataset(
            tif_paths,
            tempdir,
            256,
            expression=expression.split(),
        )
        progress(0.1, desc="Loading Model...")
        model = load_model(model_path)

        result_dir = os.path.join(tempdir, "infer")
        comb_result_dir = os.path.join(tempdir, "comb")
        os.makedirs(result_dir)
        os.makedirs(comb_result_dir)
        progress(0.2, desc="Running Inference...")
        run_inference(dataset, model, result_dir)
        progress(0.8, desc="Preparing output...")
        combine_seg_maps(result_dir, comb_result_dir)
        results = get_combined_map_contours(comb_result_dir)

        # Export while the temp directory is still alive, in case the results
        # reference files inside it. The previous run's output is discarded.
        out_dir = os.path.join(MASK_DIR, "output")
        if os.path.exists(out_dir):
            shutil.rmtree(out_dir)
        os.makedirs(out_dir)
        for tif_name, (contours, hierarchy) in results.items():
            tif_path = os.path.join(TIF_DIR, f"{tif_name}.tif")
            mask_path = os.path.join(out_dir, f"{tif_name}_mask.shp")
            zip_path = contours_to_shapefile(contours, hierarchy, tif_path, mask_path)
            file_paths.append(zip_path)

    print(time() - start, "seconds")
    return file_paths
| |
|
| |
|
def gr_save_mask_file(file_objs, filenames, obj_class):
    """Persist uploaded mask files into MASK_DIR and catalogue the shapefiles.

    Every named file (.shp, .shx, .dbf, ...) is stored. Only names ending in
    ".shp" get a row in the mask catalogue.

    Args:
        file_objs: Uploaded file objects from the gr.File component; ``.name``
            is the temporary upload path, which is moved into MASK_DIR.
        filenames: ";"-separated target names, paired in order with
            ``file_objs``. Blank entries are skipped and do not consume a file.
        obj_class: Class label recorded for each catalogued shapefile.

    Returns:
        The refreshed mask catalogue DataFrame, followed by two identical
        dropdown updates listing the catalogued mask names.

    Raises:
        IndexError: If more non-blank names are given than files uploaded.
    """
    print("Saving file(s)...")
    # Strip whitespace and drop any directory part, so a typed name cannot
    # escape MASK_DIR. The cleaned name is used both on disk and in the CSV;
    # previously the CSV got the unstripped name.
    names = [os.path.basename(part.strip()) for part in filenames.split(";")]
    names = [name for name in names if name]
    for idx, name in enumerate(names):
        filepath = os.path.join(MASK_DIR, name)
        shutil.move(file_objs[idx].name, filepath)
        if name.endswith(".shp"):
            add_to_info_csv(DATASET_MASK_INFO_PATH, [name, obj_class, filepath])
    print("Done!")

    dataset_df = pd.read_csv(DATASET_MASK_INFO_PATH)
    # Build the choices from the catalogue just read. The module-global
    # dataset_mask_df is a startup snapshot and never contains new uploads.
    update = gr.Dropdown.update(choices=dataset_df["Name"].to_list())

    return dataset_df, update, update
| |
|
| |
|
def gr_save_tif_file(file_objs, filenames):
    """Copy uploaded TIF images into TIF_DIR and register them in the TIF catalogue.

    Args:
        file_objs: Uploaded file objects from the gr.File component; ``.name``
            is the temporary upload path, which is copied into TIF_DIR.
        filenames: ";"-separated target names, paired in order with
            ``file_objs``. Blank entries are skipped and do not consume a file.

    Returns:
        The refreshed TIF catalogue DataFrame, followed by two identical
        dropdown updates listing the catalogued TIF names.

    Raises:
        IndexError: If more non-blank names are given than files uploaded.
    """
    print("Saving file(s)...")
    # Strip whitespace and drop any directory part, so a typed name cannot
    # escape TIF_DIR. The cleaned name is used both on disk and in the CSV.
    names = [os.path.basename(part.strip()) for part in filenames.split(";")]
    names = [name for name in names if name]
    for idx, name in enumerate(names):
        filepath = os.path.join(TIF_DIR, name)
        shutil.copy2(file_objs[idx].name, filepath)
        n_channels = get_tif_n_channels(filepath)
        add_to_info_csv(DATASET_TIF_INFO_PATH, [name, n_channels, filepath])
    print("Done!")

    dataset_df = pd.read_csv(DATASET_TIF_INFO_PATH)
    # The choices must come from the TIF catalogue. Previously the global
    # *mask* DataFrame was used, so the TIF dropdowns filled with mask names.
    update = gr.Dropdown.update(choices=dataset_df["Name"].to_list())

    return dataset_df, update, update
| |
|
| |
|
def gr_generate_map(mask_name: str, token: str = "", show_grid=True, show_mask=False):
    """Build a Plotly mapbox figure that visualises a saved mask shapefile.

    Args:
        mask_name: Shapefile name relative to MASK_DIR.
        token: Mapbox access token. When non-empty, the "satellite-streets"
            style with an extra USGS imagery raster layer is used. Otherwise
            the tokenless "open-street-map" style is used.
        show_grid: Plot the points from shapefile_to_grid_indices. This also
            writes them to "<mask stem>-grid.shp" next to the mask.
        show_mask: Plot every contour from shapefile_to_latlong as a
            self-filled trace.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    mask_path = os.path.join(MASK_DIR, mask_name)
    # Fixed initial view centre, as (lat, lon).
    center = (7.753769, 80.691730)

    scattermaps = []
    if show_grid:
        indices = shapefile_to_grid_indices(mask_path)
        # Side effect: persist the grid beside the mask. splitext replaces a
        # slice that blindly removed the last 4 characters of the path.
        grid_path = os.path.splitext(mask_path)[0] + "-grid.shp"
        points_to_shapefile(indices, grid_path)
        # Column 0 is plotted as longitude, column 1 as latitude.
        scattermaps.append(
            go.Scattermapbox(
                lat=indices[:, 1],
                lon=indices[:, 0],
                mode="markers",
                marker=go.scattermapbox.Marker(size=6),
            )
        )
    if show_mask:
        contours = shapefile_to_latlong(mask_path)
        # Plot every contour. A leftover debug slice (contours[38:39]) used to
        # restrict the plot to a single contour.
        for contour in contours:
            lons = contour[:, 0]
            lats = contour[:, 1]
            scattermaps.append(
                go.Scattermapbox(
                    fill="toself",
                    lat=lats,
                    lon=lons,
                    mode="markers",
                    marker=go.scattermapbox.Marker(size=6),
                )
            )

    fig = go.Figure(scattermaps)

    if token:
        fig.update_layout(
            mapbox=dict(
                style="satellite-streets",
                accesstoken=token,
                center=go.layout.mapbox.Center(lat=center[0], lon=center[1]),
                pitch=0,
                zoom=7,
            ),
            mapbox_layers=[
                {
                    "sourcetype": "raster",
                    "sourceattribution": "United States Geological Survey",
                    "source": [
                        "https://basemap.nationalmap.gov/arcgis/rest/services/USGSImageryOnly/MapServer/tile/{z}/{y}/{x}"
                    ],
                }
            ],
        )
    else:
        fig.update_layout(
            mapbox_style="open-street-map",
            hovermode="closest",
            mapbox=dict(
                bearing=0,
                center=go.layout.mapbox.Center(lat=center[0], lon=center[1]),
                pitch=0,
                zoom=7,
            ),
        )

    return fig
| |
|
| |
|
# Build the Gradio UI. In Blocks, components render in creation order, so the
# statement order below defines the page layout.
with gr.Blocks() as demo:
    gr.Markdown(
        """# SatSeg
Train models and run inference for segmentation of multispectral satellite images."""
    )

    # Catalogues are read once here and used to seed every table and dropdown.
    # (The Datasets tab used to re-read the same two CSVs redundantly.)
    model_df = pd.read_csv(MODEL_INFO_PATH)
    dataset_tif_df = pd.read_csv(DATASET_TIF_INFO_PATH)
    dataset_mask_df = pd.read_csv(DATASET_MASK_INFO_PATH)

    with gr.Tab("Train"):
        train_tif_names = gr.Dropdown(
            label="TIF Files",
            choices=dataset_tif_df["Name"].to_list(),
            multiselect=True,
        )
        train_mask_names = gr.Dropdown(
            label="Mask files",
            choices=dataset_mask_df["Name"].to_list(),
            multiselect=True,
        )
        # Whitespace-separated expression tokens, forwarded to gr_train_model.
        train_rs_index = gr.Textbox(
            label="Remote Sensing Index", placeholder="( c0 + c1 ) / ( c0 - c1 ) ="
        )
        train_model_name = gr.Textbox(
            label="Model Name", placeholder="Give the model a name"
        )
        train_button = gr.Button("Train")

        train_completion = gr.Text(label="Training Status", value="Not Started")

    with gr.Tab("Infer"):
        infer_tif_names = gr.Dropdown(
            label="TIF Files",
            choices=dataset_tif_df["Name"].to_list(),
            multiselect=True,
        )
        infer_model_name = gr.Dropdown(
            label="Model Name",
            choices=model_df["Name"].to_list(),
        )
        infer_button = gr.Button("Infer")

        infer_mask = gr.Files(label="Output Shapefile", interactive=False)

    with gr.Tab("Datasets"):
        datasets_upload_tif = gr.File(label="Images (.tif)", file_count="multiple")
        datasets_upload_tif_name = gr.Textbox(
            label="TIF name", placeholder="tif_file_1.tif;tif_file_2.tif"
        )
        datasets_save_uploaded_tif = gr.Button("Save")

        datasets_upload_mask = gr.File(
            label="Masks (Please upload all extensions (.shp, .shx, etc.))",
            file_count="multiple",
        )
        datasets_upload_mask_name = gr.Textbox(
            label="Mask name", placeholder="mask_1.shp;mask_1.shx"
        )
        datasets_mask_class_name = gr.Textbox(
            label="Class (The name of the object you want to segment)"
        )
        datasets_save_uploaded_mask = gr.Button("Save")

        datasets_tif_table = gr.Dataframe(dataset_tif_df, label="TIFs")
        datasets_mask_table = gr.Dataframe(dataset_mask_df, label="Masks")

    with gr.Tab("Models"):
        models_table = gr.Dataframe(model_df)

    # --- Event wiring (must be inside the Blocks context) ---

    # Input order matches gr_train_model(tif_names, mask_names, model_name, expression).
    train_button.click(
        gr_train_model,
        inputs=[
            train_tif_names,
            train_mask_names,
            train_model_name,
            train_rs_index,
        ],
        outputs=[train_completion, models_table, infer_model_name],
    )

    infer_button.click(
        gr_run_inference,
        inputs=[infer_tif_names, infer_model_name],
        outputs=[infer_mask],
    )

    # On upload, pre-fill the name textbox with the ";"-joined basenames of the
    # uploaded files' original names.
    datasets_upload_tif.upload(
        lambda y: ";".join(list(map(lambda x: os.path.basename(x.orig_name), y))),
        inputs=datasets_upload_tif,
        outputs=datasets_upload_tif_name,
    )

    datasets_upload_mask.upload(
        lambda y: ";".join(list(map(lambda x: os.path.basename(x.orig_name), y))),
        inputs=datasets_upload_mask,
        outputs=datasets_upload_mask_name,
    )

    # Saving refreshes the catalogue table and the dropdowns that list the files.
    datasets_save_uploaded_tif.click(
        gr_save_tif_file,
        inputs=[datasets_upload_tif, datasets_upload_tif_name],
        outputs=[datasets_tif_table, train_tif_names, infer_tif_names],
    )
    # Only two output components are bound here; gr_save_mask_file returns a
    # 3-tuple whose last element duplicates the dropdown update.
    datasets_save_uploaded_mask.click(
        gr_save_mask_file,
        inputs=[
            datasets_upload_mask,
            datasets_upload_mask_name,
            datasets_mask_class_name,
        ],
        outputs=[datasets_mask_table, train_mask_names],
    )
| |
|
# Turn on Gradio's request queue (concurrency_count=10) and start the server in
# debug mode. queue() returns the same Blocks instance, so launch it directly.
queued_demo = demo.queue(concurrency_count=10)
queued_demo.launch(debug=True)
| |
|