import functools

import gradio as gr
import numpy as np
import torch

from stldm import InferenceHub
from stldm.config import STLDM_HKO
from utilspp import resize, gradio_gif, gradio_visualize

# Number of most recent input frames fed to STLDM.
_NUM_INPUT_FRAMES = 5
# Target size handed to ``resize`` before inference.
_INPUT_SIZE = 128


@functools.lru_cache(maxsize=1)
def _load_forecaster(cfg_str):
    """Return a memoized STLDM ``InferenceHub`` for one guidance setting.

    Reusing the instance avoids rebuilding the model on every button click.
    ``maxsize=1`` keeps only the most recently used configuration alive, so
    moving the guidance slider never holds more than one instance at a time.

    Args:
        cfg_str: Classifier-free guidance scale, or ``None`` to disable it.
    """
    return InferenceHub(
        model_config=STLDM_HKO,
        cfg_str=cfg_str,
        model_type='HF',
    )


def nowcasting(file, cfg_str, ensemble_no):
    """Run STLDM on an uploaded radar sequence and render the results.

    Args:
        file: The ``gr.File`` upload. Either a tempfile-like object exposing
            ``.name`` (older Gradio) or a path string (Gradio 4+ default).
            The ``.npy`` array must be 4-D (T, C, H, W) or 5-D
            (B, T, C, H, W); for 5-D input only the first batch item is used.
        cfg_str: Classifier-free guidance scale; values <= 0 disable guidance.
        ensemble_no: Number of ensemble members to sample. At least one is
            always produced, alongside the deterministic ``mu`` output.

    Returns:
        ``(past_frames, figure)`` as produced by ``gradio_visualize`` and
        ``gradio_gif``.

    Raises:
        ValueError: If no file was uploaded, the array is not 4-D or 5-D,
            or it holds fewer than 5 frames.
    """
    if file is None:
        raise ValueError("Please upload a .npy file with the input radar sequence")
    path = getattr(file, "name", file)

    # Cast up front: otherwise integer arrays whose max is <= 1 skip the
    # division below and stay integer, and float64 arrays stay double.
    x = torch.as_tensor(np.load(path), dtype=torch.float32)
    if x.ndim not in (4, 5):
        raise ValueError("Please specify the input has the format of (T C H W) or (B T C H W)")

    # Values above 1 are rescaled by 1/255 (presumably 8-bit intensities).
    # The numel guard keeps an empty array from crashing in ``max()``; the
    # frame-count check below then reports it properly.
    if x.numel() > 0 and x.max() > 1:
        x = x / 255.0
    x = x.clamp(0, 1)
    if x.ndim == 4:
        x = x.unsqueeze(0)  # add a leading batch axis
    x = resize(x, _INPUT_SIZE)

    if x.shape[1] < _NUM_INPUT_FRAMES:
        raise ValueError("The input should have at least 5 frames for STLDM to predict")
    # First batch item, most recent frames only.
    x = x[0, -_NUM_INPUT_FRAMES:]

    forecaster = _load_forecaster(cfg_str if cfg_str > 0 else None)
    y_pred, mu = forecaster(input_x=x, include_mu=True)
    out = {'Deterministic': mu, 'Ensemble 1': y_pred}
    # Slider values may arrive as floats; range() needs an int.
    for i in range(1, int(ensemble_no)):
        out[f'Ensemble {i + 1}'] = forecaster(input_x=x, include_mu=False)

    past_frames = gradio_visualize(x)
    figure = gradio_gif(out, len(out['Ensemble 1']))
    return past_frames, figure


# Gradio UI: upload a radar sequence, tune sampling parameters, and view
# the past frames next to the animated ensemble predictions.
with gr.Blocks() as demo:
    gr.Markdown("# STLDM Official Demo for **HKO-7** Nowcasting")
    gr.Markdown("Please upload a radar sequence with **at least 5 frames** as a .npy file of shape (T, C, H, W) or (B, T, C, H, W); **STLDM** will predict the next 20 frames from the most recent 5 frames.")
    gr.Markdown('**Paper** - [STLDM: Spatio-Temporal Latent Diffusion Model for Precipitation Nowcasting](https://arxiv.org/abs/2512.21118)')
    gr.Markdown('**Code** - [https://github.com/sqfoo/stldm_official](https://github.com/sqfoo/stldm_official)')

    gr.Markdown("## Input Frames")
    file_input = gr.File(label="Upload the input radar sequences", file_types=[".npy"])

    gr.Markdown("## Parameters")
    # nowcasting() treats a scale of 0 as "guidance disabled".
    cfg_str = gr.Slider(0.0, 2.0, value=1.0, step=0.1, label="Classifier Free Guidance Scale")
    ensemble_no = gr.Slider(1, 10, value=2, step=1, label="How many ensemble predictions?")

    gr.Markdown("## Predictions")
    input_frames = gr.Image(label="Past 5 frames")
    prediction = gr.Image(label="Evolving Predictions")
    btn = gr.Button("Forecast Now!")
    btn.click(fn=nowcasting, inputs=[file_input, cfg_str, ensemble_no], outputs=[input_frames, prediction])

if __name__ == "__main__":
    # show_error=True makes Gradio display exception messages in the UI, so
    # the input-validation errors raised by nowcasting() reach the user.
    demo.launch(share=True, show_error=True)