| | import rasterio as rio |
| | import pathlib |
| | import opensr_test |
| | import matplotlib.pyplot as plt |
| |
|
| | from typing import Callable, Union |
| |
|
| |
|
def create_geotiff(
    model: Callable,
    fn: Callable,
    datasets: Union[str, list],
    output_path: str,
    force: bool = False,
    **kwargs
) -> None:
    """Create the SR GeoTIFFs for one or more dataset snippets.

    Every snippet is processed independently with ``create_geotiff_batch``.

    Args:
        model (Callable): The model passed through to ``fn``.
        fn (Callable): A function called as
            ``fn(model=..., lr=..., hr=..., **kwargs)`` that returns a
            dictionary with the following keys:
                - "lr": Low resolution image
                - "sr": Super resolution image
                - "hr": High resolution image
        datasets (Union[str, list]): The dataset snippets to process. Either
            an iterable of snippet names, a single snippet name, or the
            string ``"all"`` to use ``opensr_test.datasets``.
        output_path (str): The root output path under which the GeoTIFFs
            are saved.
        force (bool, optional): If True, the dataset is redownloaded
            (forwarded to ``create_geotiff_batch``). Defaults to False.
        **kwargs: Extra keyword arguments forwarded to ``fn``.
    """
    if datasets == "all":
        datasets = opensr_test.datasets
    elif isinstance(datasets, str):
        # A single snippet name: iterating the bare string would otherwise
        # treat every character as a separate snippet.
        datasets = [datasets]

    for snippet in datasets:
        create_geotiff_batch(
            model=model,
            fn=fn,
            snippet=snippet,
            output_path=output_path,
            force=force,
            **kwargs
        )

    return None
| |
|
def _transform_from_affine_string(affine: str):
    """Build a north-up rasterio transform from a comma-separated affine string.

    Elements 0, 2, 4 and 5 of the string are used as pixel width, x origin,
    (negated) pixel height and y origin; elements 1 and 3 (rotation terms)
    are ignored. Presumably the string follows the ``Affine(a, b, c, d, e, f)``
    ordering -- TODO confirm against the dataset metadata.
    """
    values = [float(value) for value in affine.split(",")]
    return rio.transform.from_origin(values[2], values[5], values[0], -values[4])


def _save_comparison_png(results: dict, png_path: pathlib.Path, scale: float = 3000) -> None:
    """Save an LR | SR | HR side-by-side preview of one sample as a PNG.

    Each image is transposed from (bands, H, W) to (H, W, bands) for
    ``imshow``, divided by ``scale`` and clipped to [0, 1]; this is only a
    display stretch. The figure is always closed, even if saving fails.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        for ax, key, title in zip(axes, ("lr", "sr", "hr"), ("LR", "SR", "HR")):
            ax.imshow((results[key].transpose(1, 2, 0) / scale).clip(0, 1))
            ax.set_title(title)
            ax.axis("off")
        fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
        fig.savefig(png_path)
    finally:
        # Close this specific figure; calling plt.clf() after plt.close()
        # would silently create a new, never-closed figure.
        plt.close(fig)


def create_geotiff_batch(
    model: Callable,
    fn: Callable,
    snippet: str,
    output_path: str,
    force: bool = False,
    **kwargs
) -> pathlib.Path:
    """Create the SR GeoTIFFs and PNG previews for one dataset snippet.

    For every sample of the snippet, ``fn`` is called with the model and the
    LR ("L2A") / HR ("HRharm") images. Its ``"sr"`` output is written as a
    GeoTIFF georeferenced with that sample's metadata ("crs", "affine"), and
    an LR | SR | HR comparison is saved as a PNG. Files are written to::

        <output_path>/results/SR/<snippet>/geotiff/<hr_file>.tif
        <output_path>/results/SR/<snippet>/png/<hr_file>.png

    Args:
        model (Callable): The model passed through to ``fn``.
        fn (Callable): A function called as
            ``fn(model=..., lr=..., hr=..., **kwargs)`` that returns a
            dictionary with the following keys:
                - "lr": Low resolution image
                - "sr": Super resolution image
                - "hr": High resolution image
        snippet (str): The dataset snippet to process.
        output_path (str): The root output path under which results are saved.
        force (bool, optional): If True, the dataset is redownloaded
            (forwarded to ``opensr_test.load``). Defaults to False.
        **kwargs: Extra keyword arguments forwarded to ``fn``.

    Returns:
        pathlib.Path: The directory where the GeoTIFFs are saved.
    """
    root = pathlib.Path(output_path) / "results" / "SR"
    geotiff_dir = root / snippet / "geotiff"
    png_dir = root / snippet / "png"
    geotiff_dir.mkdir(parents=True, exist_ok=True)
    png_dir.mkdir(parents=True, exist_ok=True)

    dataset = opensr_test.load(snippet, force=force)
    lr_dataset = dataset["L2A"]
    hr_dataset = dataset["HRharm"]
    metadata = dataset["metadata"]

    n_samples = len(lr_dataset)
    for index in range(n_samples):
        print(f"Processing {index}/{n_samples}")

        results = fn(
            model=model,
            lr=lr_dataset[index],
            hr=hr_dataset[index],
            **kwargs
        )

        row = metadata.iloc[index]
        image_name = row["hr_file"]
        sr = results["sr"]

        # The profile is derived from the array actually written, so band
        # count and size always match `dst.write`. The HR georeference is
        # reused, so `sr` is assumed to lie on the HR pixel grid.
        meta_img = {
            "driver": "GTiff",
            "count": sr.shape[0],
            "dtype": "uint16",
            "height": sr.shape[1],
            "width": sr.shape[2],
            "crs": row["crs"],
            "transform": _transform_from_affine_string(row["affine"]),
            "compress": "deflate",
            "predictor": 2,
            "tiled": True,
        }

        with rio.open(geotiff_dir / (image_name + ".tif"), "w", **meta_img) as dst:
            dst.write(sr)

        _save_comparison_png(results, png_dir / (image_name + ".png"))

    return geotiff_dir
| |
|
| |
|
| |
|
| |
|
def run(model_path: str) -> pathlib.Path:
    """Compute all the metrics for a specific model.

    Intended to return the output path where the metrics are saved as a
    pickle file.

    NOTE: not implemented yet. The body is a placeholder that ignores
    ``model_path`` and currently returns ``None`` despite the annotation.

    Args:
        model_path (str): The path to the model folder.

    Returns:
        pathlib.Path: Intended output path of the metrics pickle
        (currently ``None``).
    """
    return None
| |
|
| |
|
def plot(model_path: str) -> pathlib.Path:
    """Generate the plots and tables for a specific model.

    Intended to return the output path where the plots and tables are saved.

    NOTE: not implemented yet. The body is a placeholder that ignores
    ``model_path`` and currently returns ``None`` despite the annotation.

    Args:
        model_path (str): The path to the model folder.

    Returns:
        pathlib.Path: Intended output directory of the plots and tables
        (currently ``None``).
    """
    return None
| |
|
| |
|
| |
|