| import numpy as np |
| import triton_python_backend_utils as pb_utils |
| from omnicloudmask import predict_from_array |
| import rasterio |
| from rasterio.io import MemoryFile |
| from rasterio.enums import Resampling |
|
|
class TritonPythonModel:
    """Triton Python-backend model that runs OmniCloudMask on JP2 band payloads.

    Each request carries an ``input_jp2_bytes`` tensor holding three encoded
    raster images in the order red, green, NIR. The green and NIR bands are
    bilinearly resampled onto the red band's grid when their size differs.
    The three bands are stacked into a ``(3, height, width)`` float32 array
    and passed to ``omnicloudmask.predict_from_array``. The prediction is
    returned as the ``output_mask`` tensor, cast to uint8.
    """

    # Name of the BYTES input tensor carrying the encoded band images.
    INPUT_NAME = "input_jp2_bytes"
    # Name of the output tensor carrying the predicted mask.
    OUTPUT_NAME = "output_mask"
    # Number of band payloads expected per request (red, green, NIR).
    EXPECTED_BANDS = 3

    def initialize(self, args):
        """
        Initialize the model. Called once by Triton when the model is loaded.

        ``args`` is part of the backend interface but is not used here.
        """
        print('Initialized Cloud Detection model with JP2 input')

    def execute(self, requests):
        """
        Process inference requests.

        Returns one ``InferenceResponse`` per request, in request order. A
        failure while handling one request is reported in that request's
        response only; the other requests in the batch are unaffected.
        """
        responses = []
        for request in requests:
            try:
                response = self._process_request(request)
            except Exception as e:
                # Top-level per-request boundary: decoding, resampling and
                # inference errors become an error response, not a crash of
                # the whole batch.
                response = self._error_response(f"Error processing JP2 data: {e}")
            responses.append(response)
        return responses

    def _process_request(self, request):
        """Decode, align and classify the three band images of one request."""
        input_tensor = pb_utils.get_input_tensor_by_name(request, self.INPUT_NAME)
        if input_tensor is None:
            # get_input_tensor_by_name returns None when the input is absent;
            # calling .as_numpy() on it would raise an opaque AttributeError.
            return self._error_response(
                f"Missing required input tensor '{self.INPUT_NAME}'"
            )

        payloads = self._flatten_band_payloads(input_tensor.as_numpy())
        if len(payloads) != self.EXPECTED_BANDS:
            return self._error_response(
                f"Expected {self.EXPECTED_BANDS} JP2 byte strings, "
                f"received {len(payloads)}"
            )
        red_bytes, green_bytes, nir_bytes = payloads

        red_data = self._read_band(red_bytes)
        # The red band defines the output grid; green and NIR are aligned to it.
        target_shape = red_data.shape
        green_data = self._read_band(green_bytes, target_shape)
        nir_data = self._read_band(nir_bytes, target_shape)

        input_array = np.stack([red_data, green_data, nir_data], axis=0)
        pred_mask = predict_from_array(input_array)

        output_tensor = pb_utils.Tensor(self.OUTPUT_NAME, pred_mask.astype(np.uint8))
        return pb_utils.InferenceResponse([output_tensor])

    @staticmethod
    def _flatten_band_payloads(raw):
        """Return the input tensor's byte payloads as a flat Python list.

        Flattening accepts a ``(3,)`` tensor as well as batched layouts such
        as ``(1, 3)``. A bare ``len()`` would only see the leading dimension
        of those layouts.
        """
        return np.asarray(raw, dtype=object).reshape(-1).tolist()

    @staticmethod
    def _read_band(band_bytes, target_shape=None):
        """Decode band 1 of an in-memory raster and return it as float32.

        Args:
            band_bytes: Encoded raster file contents.
            target_shape: Optional ``(height, width)``. If given and different
                from the raster's native size, the band is bilinearly
                resampled to it.

        Returns:
            A 2-D float32 array.
        """
        with MemoryFile(band_bytes) as memfile, memfile.open() as src:
            native_shape = (src.height, src.width)
            if target_shape is None or tuple(target_shape) == native_shape:
                data = src.read(1)
            else:
                data = src.read(
                    1,
                    out_shape=(1, *target_shape),
                    resampling=Resampling.bilinear,
                )
        return data.astype(np.float32)

    @staticmethod
    def _error_response(message):
        """Build an ``InferenceResponse`` with no outputs and ``message`` as error."""
        return pb_utils.InferenceResponse(
            output_tensors=[], error=pb_utils.TritonError(message)
        )

    def finalize(self):
        """
        Called when the model is unloaded. Perform any necessary cleanup.
        """
        print('Finalizing Cloud Detection model')
|
|