sqfoo committed on
Commit
be9a67e
·
1 Parent(s): 805021f

Add Application file

Browse files
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 sqfoo
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import gradio as gr
4
+
5
+ from stldm import InferenceHub
6
+ from stldm.config import STLDM_HKO
7
+ from data.dutils import resize
8
+ from utilspp import gradio_visualize, gradio_gif
9
+
10
+ def nowcasting(file, cfg_str, ensemble_no):
11
+ # Model Setup
12
+ Forecastor = InferenceHub(
13
+ model_config=STLDM_HKO,
14
+ cfg_str=cfg_str,
15
+ model_type='HF'
16
+ )
17
+
18
+ # Data Preparation
19
+ x = torch.tensor(np.load(file.name))
20
+ if x.ndim not in (5, 4):
21
+ raise ValueError("Please specify the input has the format of (T C H W)")
22
+
23
+ if x.max() > 1:
24
+ x = x / 255.0
25
+ x = x.clamp(0, 1)
26
+ if x.ndim == 4:
27
+ x = x.unsqueeze(0)
28
+ x = resize(x, 128) # resize the data to 128 x 128
29
+
30
+ if x.shape[1] < 5:
31
+ raise ValueError("The input should have at least 5 frames for STLDM to predict")
32
+ x = x[0, -5:]
33
+
34
+ out = {}
35
+ for i in range(ensemble_no):
36
+ y_pred = Forecastor(input_x=x, include_mu=False)
37
+ out[f'Ensemble {i+1}'] = torch.cat((x, y_pred), dim=0)
38
+
39
+ figure = gradio_gif(out, len(out['Ensemble 1']))
40
+
41
+ return figure
42
+
43
+ with gr.Blocks() as demo:
44
+ gr.Markdown("# STLDM official demo for nowcasting")
45
+ gr.Markdown("Please upload the radar sequences with **at least 5 frames** in the format of .npy file, and **STLDM** will predict the future 20 frames based on the past 5 frames.")
46
+ gr.Markdown('Please refer to [paper](https://arxiv.org/abs/2512.21118) and [code](https://github.com/sqfoo/stldm_official) for more details about STLDM.')
47
+
48
+ file_input = gr.File(label="Upload the input radar squences", file_types=[".npy"])
49
+ cfg_str = gr.Slider(0.0, 2.0, value=1.0, step=0.1, label="Classifier Free Guidance Scale")
50
+ ensemble_no = gr.Slider(1, 10, value=2, step=1, label="How many ensemble predictions?")
51
+
52
+ output = gr.Image(label="Nowcasting Results")
53
+ btn = gr.Button("Forecast Now!")
54
+ btn.click(fn=nowcasting, inputs=[file_input, cfg_str, ensemble_no], outputs=output)
55
+
56
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.0.0
2
+ antlr4-python3-runtime==4.9.3
3
+ anyio==4.12.0
4
+ argon2-cffi==25.1.0
5
+ argon2-cffi-bindings==25.1.0
6
+ arrow==1.4.0
7
+ asttokens==3.0.1
8
+ async-lru==2.0.5
9
+ attrs==25.4.0
10
+ babel==2.17.0
11
+ beautifulsoup4==4.14.3
12
+ bleach==6.2.0
13
+ cachetools==5.3.2
14
+ certifi==2023.11.17
15
+ cffi==2.0.0
16
+ charset-normalizer==3.3.2
17
+ comm==0.2.3
18
+ contourpy==1.3.0
19
+ cycler==0.12.1
20
+ debugpy==1.8.17
21
+ decorator==5.2.1
22
+ defusedxml==0.7.1
23
+ einops==0.8.1
24
+ exceptiongroup==1.3.1
25
+ executing==2.2.1
26
+ fastjsonschema==2.21.2
27
+ fonttools==4.45.0
28
+ fqdn==1.5.1
29
+ google-auth==2.23.4
30
+ google-auth-oauthlib==0.4.6
31
+ grpcio==1.59.3
32
+ h11==0.16.0
33
+ h5py==3.7.0
34
+ httpcore==1.0.9
35
+ httpx==0.28.1
36
+ idna==3.4
37
+ imageio==2.33.0
38
+ importlib-metadata==6.8.0
39
+ importlib_resources==6.5.2
40
+ ipykernel==6.31.0
41
+ ipython==8.18.1
42
+ ipywidgets==8.1.8
43
+ isoduration==20.11.0
44
+ jedi==0.19.2
45
+ Jinja2==3.1.6
46
+ joblib==1.3.2
47
+ json5==0.12.1
48
+ jsonpointer==3.0.0
49
+ jsonschema==4.25.1
50
+ jsonschema-specifications==2025.9.1
51
+ jupyter==1.1.1
52
+ jupyter-console==6.6.3
53
+ jupyter-events==0.12.0
54
+ jupyter-lsp==2.3.0
55
+ jupyter_client==8.6.3
56
+ jupyter_core==5.8.1
57
+ jupyter_server==2.17.0
58
+ jupyter_server_terminals==0.5.3
59
+ jupyterlab==4.5.0
60
+ jupyterlab_pygments==0.3.0
61
+ jupyterlab_server==2.28.0
62
+ jupyterlab_widgets==3.0.16
63
+ kiwisolver==1.4.5
64
+ lark==1.3.1
65
+ lpips==0.1.4
66
+ Markdown==3.5.1
67
+ MarkupSafe==2.1.3
68
+ matplotlib==3.9.4
69
+ matplotlib-inline==0.2.1
70
+ mistune==3.1.4
71
+ nbclient==0.10.2
72
+ nbconvert==7.16.6
73
+ nbformat==5.10.4
74
+ nest-asyncio==1.6.0
75
+ networkx==3.2.1
76
+ notebook==7.5.0
77
+ notebook_shim==0.2.4
78
+ numpy==1.24.4
79
+ oauthlib==3.2.2
80
+ omegaconf==2.3.0
81
+ opencv-python==4.8.0.74
82
+ overrides==7.7.0
83
+ packaging==23.2
84
+ pandas==1.4.3
85
+ pandocfilters==1.5.1
86
+ parso==0.8.5
87
+ pexpect==4.9.0
88
+ Pillow==10.1.0
89
+ platformdirs==4.4.0
90
+ prometheus_client==0.23.1
91
+ prompt_toolkit==3.0.52
92
+ protobuf==3.19.6
93
+ psutil==7.1.3
94
+ ptyprocess==0.7.0
95
+ pure_eval==0.2.3
96
+ pyasn1==0.5.1
97
+ pyasn1-modules==0.3.0
98
+ pycparser==2.23
99
+ Pygments==2.19.2
100
+ pyparsing==3.1.1
101
+ python-dateutil==2.8.2
102
+ python-json-logger==4.0.0
103
+ pytz==2023.3.post1
104
+ PyWavelets==1.5.0
105
+ PyYAML==6.0
106
+ pyzmq==27.1.0
107
+ referencing==0.36.2
108
+ requests==2.31.0
109
+ requests-oauthlib==1.3.1
110
+ rfc3339-validator==0.1.4
111
+ rfc3986-validator==0.1.1
112
+ rfc3987-syntax==1.1.0
113
+ rpds-py==0.27.1
114
+ rsa==4.9
115
+ SciencePlots==2.2.0
116
+ scikit-image==0.19.3
117
+ scikit-learn==1.1.2
118
+ scipy==1.9.1
119
+ Send2Trash==1.8.3
120
+ six==1.16.0
121
+ soupsieve==2.8
122
+ stack-data==0.6.3
123
+ tensorboard==2.9.0
124
+ tensorboard-data-server==0.6.1
125
+ tensorboard-plugin-wit==1.8.1
126
+ terminado==0.18.1
127
+ threadpoolctl==3.2.0
128
+ tifffile==2023.9.26
129
+ tinycss2==1.4.0
130
+ tomli==2.3.0
131
+ torch==1.12.1+cu116
132
+ torchmetrics==0.11.0
133
+ torchvision==0.13.1+cu116
134
+ tornado==6.5.2
135
+ tqdm==4.66.1
136
+ traitlets==5.14.3
137
+ typing_extensions==4.8.0
138
+ tzdata==2025.2
139
+ uri-template==1.3.0
140
+ urllib3==2.1.0
141
+ wcwidth==0.2.14
142
+ webcolors==24.11.1
143
+ webencodings==0.5.1
144
+ websocket-client==1.9.0
145
+ Werkzeug==3.0.1
146
+ widgetsnbextension==4.0.15
147
+ zipp==3.17.0
stldm/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from stldm.stldm import model_setup
2
+ from stldm.stldm_spatial import model_setup as spatial_setup
3
+ from stldm.inference import InferenceHub
4
+
5
+ n2n_setup = {'2D': spatial_setup, '3D': model_setup}
stldm/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (343 Bytes). View file
 
stldm/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (1.15 kB). View file
 
stldm/__pycache__/config.cpython-38.pyc ADDED
Binary file (1.08 kB). View file
 
stldm/__pycache__/inference.cpython-38.pyc ADDED
Binary file (3.26 kB). View file
 
stldm/__pycache__/modules.cpython-38.pyc ADDED
Binary file (5.45 kB). View file
 
stldm/__pycache__/modules.cpython-39.pyc ADDED
Binary file (5.4 kB). View file
 
stldm/__pycache__/simvpv2.cpython-38.pyc ADDED
Binary file (15.4 kB). View file
 
stldm/__pycache__/simvpv2.cpython-39.pyc ADDED
Binary file (15.2 kB). View file
 
stldm/__pycache__/stldm.cpython-38.pyc ADDED
Binary file (18.4 kB). View file
 
stldm/__pycache__/stldm.cpython-39.pyc ADDED
Binary file (18.4 kB). View file
 
stldm/__pycache__/stldm_hf.cpython-38.pyc ADDED
Binary file (18.6 kB). View file
 
stldm/__pycache__/stldm_spatial.cpython-38.pyc ADDED
Binary file (18.2 kB). View file
 
stldm/__pycache__/stldm_spatial.cpython-39.pyc ADDED
Binary file (18.2 kB). View file
 
stldm/__pycache__/submodules.cpython-38.pyc ADDED
Binary file (15.4 kB). View file
 
stldm/__pycache__/submodules.cpython-39.pyc ADDED
Binary file (15.4 kB). View file
 
stldm/config.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ STLDM_SEVIR = {
2
+ 'model': "stldm",
3
+ 'pre': None,
4
+ 'post': None,
5
+ 'vp_param': {
6
+ 'shape_in': (13, 1, 128, 128),
7
+ 'shape_out': (12, 1, 128, 128),
8
+ 'hid_S': 32,
9
+ 'hid_T': 512,
10
+ 'N_S': 4,
11
+ 'N_T': 8,
12
+ 'groups': 8,
13
+ 'last_activation': 'sigmoid',
14
+ },
15
+ 'stldm_param': {
16
+ 'in_ch': 32,
17
+ 'chs_mult': [1,2,4,8],
18
+ 'num_groups': 8,
19
+ 'heads': 4,
20
+ 'dim_head': 32,
21
+ 'base_ch': 64,
22
+ 'patch_size': 16
23
+ },
24
+ 'param': {
25
+ 'timesteps': 50,
26
+ 'sampling_timesteps': 20,
27
+ 'objective': 'pred_v'
28
+ }
29
+ }
30
+
31
+ STLDM_HKO = {
32
+ 'model': "stldm",
33
+ 'pre': None,
34
+ 'post': None,
35
+ 'vp_param': {
36
+ 'shape_in': (5, 1, 128, 128),
37
+ 'shape_out': (20, 1, 128, 128),
38
+ 'hid_S': 32,
39
+ 'hid_T': 512,
40
+ 'N_S': 4,
41
+ 'N_T': 8,
42
+ 'groups': 8,
43
+ 'last_activation': 'sigmoid',
44
+ },
45
+ 'stldm_param': {
46
+ 'in_ch': 32,
47
+ 'chs_mult': [1,2,4,8],
48
+ 'num_groups': 8,
49
+ 'heads': 4,
50
+ 'dim_head': 32,
51
+ 'base_ch': 64,
52
+ 'patch_size': 16
53
+ },
54
+ 'param': {
55
+ 'timesteps': 50,
56
+ 'sampling_timesteps': 20,
57
+ 'objective': 'pred_v'
58
+ }
59
+ }
60
+
61
+ STLDM_METEO = {
62
+ 'model': "stldm",
63
+ 'pre': None,
64
+ 'post': None,
65
+ 'vp_param': {
66
+ 'shape_in': (5, 1, 128, 128),
67
+ 'shape_out': (20, 1, 128, 128),
68
+ 'hid_S': 32,
69
+ 'hid_T': 512,
70
+ 'N_S': 4,
71
+ 'N_T': 8,
72
+ 'groups': 8,
73
+ 'last_activation': 'sigmoid',
74
+ },
75
+ 'stldm_param': {
76
+ 'in_ch': 32,
77
+ 'chs_mult': [1,2,4,8],
78
+ 'num_groups': 8,
79
+ 'heads': 4,
80
+ 'dim_head': 32,
81
+ 'base_ch': 64,
82
+ 'patch_size': 16
83
+ },
84
+ 'param': {
85
+ 'timesteps': 50,
86
+ 'sampling_timesteps': 20,
87
+ 'objective': 'pred_v'
88
+ }
89
+ }
90
+
91
+
92
+ STLDM_HKO_HF = {
93
+ 'vp_param': {
94
+ 'shape_in': (5, 1, 128, 128),
95
+ 'shape_out': (20, 1, 128, 128),
96
+ 'hid_S': 32,
97
+ 'hid_T': 512,
98
+ 'N_S': 4,
99
+ 'N_T': 8,
100
+ 'groups': 8,
101
+ 'last_activation': 'sigmoid',
102
+ },
103
+ 'stldm_param': {
104
+ 'in_ch': 32,
105
+ 'chs_mult': [1,2,4,8],
106
+ 'num_groups': 8,
107
+ 'heads': 4,
108
+ 'dim_head': 32,
109
+ 'base_ch': 64,
110
+ 'patch_size': 16
111
+ },
112
+ 'timesteps': 50,
113
+ 'sampling_timesteps': 20,
114
+ 'objective': 'pred_v'
115
+ }
stldm/inference.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Tuple
3
+
4
+ from stldm.stldm import model_setup, guidance_scheduler
5
+ from stldm.stldm_spatial import model_setup as spatial_setup
6
+ from stldm.stldm_hf import GaussianDiffusion as hf_setup
7
+
8
+ n2n_setup = {'2D': spatial_setup, '3D': model_setup, 'HF': hf_setup}
9
+
10
+ class InferenceHub:
11
+ """
12
+ Unified inference interface for STLDM models
13
+
14
+ Support local checkpoints and the checkpoint uploaded to Hugging Face.
15
+
16
+ Params:
17
+ - model_config: dict, the model configuration found in "stldm/model_config.py"
18
+ - model_ckpt: str, the path to the model checkpoint. For 'HF' model_type, this can be None.
19
+ - cfg_str: float, the classifier-free guidance strength. If None, no CFG is applied.
20
+ - model_type: str, the type of the model. Options are '2D', '3D', and 'HF'.
21
+ """
22
+
23
+ def __init__(self, model_config, model_ckpt:str=None, cfg_str:float=None, model_type:str='3D', gpu='auto'):
24
+ self.input_size = model_config['vp_param']['shape_in']
25
+ self.sampling_steps = model_config['param']['timesteps']
26
+ self.model_config = self.setup_config(model_config, model_type)
27
+
28
+ self.model = self.setup_model(model_type, self.model_config, model_ckpt)
29
+ self.setup_cfg(cfg_str)
30
+
31
+ if gpu is not None:
32
+ if gpu == 'auto':
33
+ if torch.cuda.device_count() > 0:
34
+ self.model.to(device="cuda")
35
+ else:
36
+ self.model.to(device=f"cuda:{gpu}")
37
+
38
+ def setup_config(self, model_config, model_type):
39
+ if model_type == 'HF':
40
+ HF_config = {
41
+ 'vp_param': model_config['vp_param'],
42
+ 'stldm_param': model_config['stldm_param'],
43
+ **model_config['param'],
44
+ }
45
+ return HF_config
46
+ else:
47
+ return model_config
48
+
49
+ def setup_model(self, model_type, model_config, model_ckpt):
50
+ if model_type not in n2n_setup:
51
+ raise ValueError(f"model_type should be one of {str(list(n2n_setup.keys()))}")
52
+
53
+ if model_type == 'HF':
54
+ model = n2n_setup[model_type](**model_config).from_pretrained("sqfoo/STLDM_official")
55
+ else:
56
+ model = n2n_setup[model_type](model_config)
57
+ model.load_state_dict(torch.load(model_ckpt))
58
+ model.eval()
59
+ return model
60
+
61
+ def setup_cfg(self, cfg_str):
62
+ guidance = guidance_scheduler(sampling_step=self.sampling_steps, const=cfg_str) if cfg_str is not None else None
63
+ self.model.setup_guidance(guidance)
64
+
65
+ """
66
+ This method performs inference on the input tensor.
67
+
68
+ Params:
69
+ - input_x: torch.tensor, the input tensor with shape (B T C H W) or (T C H W)
70
+ - include_mu: bool, whether to return the intermediate representation 'mu' along with the final prediction
71
+ """
72
+ @torch.no_grad()
73
+ def __call__(self, input_x: torch.tensor, include_mu: bool = False):
74
+ ndim = input_x.ndim
75
+ if ndim not in (5, 4):
76
+ raise ValueError("Please specify the input has the either format of (B T C H W) or (T C H W)")
77
+ input_device = input_x.device
78
+
79
+ if ndim == 4:
80
+ input_x = input_x.unsqueeze(0)
81
+
82
+ if input_x.shape[1:] != self.input_size:
83
+ raise ValueError(f"Ensure that the input has the shape of {str(self.input_size)}")
84
+
85
+ input_x = input_x.to(self.model.device)
86
+ if include_mu:
87
+ y_pred, mu = self.model(input_x, includ_mu=include_mu)
88
+ else:
89
+ y_pred = self.model(input_x, includ_mu=include_mu)
90
+ mu = None
91
+
92
+ if mu is not None:
93
+ mu = mu.to(input_device)
94
+ y_pred = y_pred.to(input_device)
95
+
96
+ if ndim == 4:
97
+ y_pred = y_pred[0]
98
+ mu = mu if mu is None else mu[0]
99
+ return (y_pred, mu) if include_mu else y_pred
stldm/modules.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from stldm.submodules import ChannelConversion
5
+ from stldm.simvpv2 import stride_generator, ConvSC, MidMetaNet
6
+
7
+ class Encoder(nn.Module):
8
+ def __init__(self, C_in, C_hid, N_S):
9
+ super(Encoder, self).__init__()
10
+ strides = stride_generator(N_S)
11
+ self.enc = nn.Sequential(
12
+ ConvSC(C_in, C_hid, stride=strides[0]),
13
+ *[ConvSC(C_hid, C_hid, stride=s) for s in strides[1:]],
14
+ ChannelConversion(C_hid, 2*C_hid)
15
+ )
16
+
17
+ def forward(self, x):
18
+ for encoder in self.enc:
19
+ x = encoder(x)
20
+ (mean, log_var) = torch.chunk(x, 2, dim=1)
21
+ return mean, log_var
22
+
23
+ class Decoder(nn.Module):
24
+ def __init__(self, C_hid, C_out, N_S, last_activation='sigmoid'):
25
+ super(Decoder,self).__init__()
26
+ strides = stride_generator(N_S, reverse=True)
27
+ self.dec = nn.Sequential(
28
+ ChannelConversion(C_hid, C_hid),
29
+ *[ConvSC(C_hid, C_hid, stride=s, transpose=True) for s in strides[:-1]],
30
+ ConvSC(C_hid, C_hid, stride=strides[-1], transpose=True)# Modify HERE
31
+ )
32
+ self.readout = nn.Conv2d(C_hid, C_out, 1)
33
+ if last_activation=='sigmoid':
34
+ self.last = nn.Sigmoid()
35
+ else:
36
+ self.last = nn.Identity()
37
+
38
+ def forward(self, x):
39
+ for decoder in self.dec:
40
+ x = decoder(x)
41
+ Y = self.readout(x)
42
+ return self.last(Y)
43
+
44
+
45
+ class VAE(nn.Module):
46
+ def __init__(self, C_in, hid_S, N_S, last_activation='none'):
47
+ super(VAE, self).__init__()
48
+ self.encoder = Encoder(C_in, hid_S, N_S)
49
+ self.decoder = Decoder(hid_S, C_in, N_S, last_activation)
50
+
51
+ def sample_from_standard_normal(self, mean, log_var):
52
+ std = (0.5 * log_var).exp()
53
+ return mean + std * torch.randn_like(mean)
54
+
55
+ def encode(self, x):
56
+ assert x.ndim==4
57
+ mean, log_var = self.encoder(x)
58
+ return mean, log_var
59
+
60
+ def decode(self, z):
61
+ assert z.ndim==4
62
+ dec = self.decoder(z)
63
+ return dec
64
+
65
+ def kl_from_standard_normal(self, mean, log_var):
66
+ kl = 0.5 * (log_var.exp() + mean.square() - 1.0 - log_var)
67
+ return kl.mean()
68
+
69
+ def _losses_(self, x, y):
70
+ mean, log_var = self.encode(x)
71
+ kl_loss = self.kl_from_standard_normal(mean, log_var)
72
+
73
+ y_pred = self.forward(x)
74
+ recon_loss = nn.MSELoss()(y_pred, y)
75
+ return recon_loss, kl_loss
76
+
77
+ def forward(self, x):
78
+ mu_z, log_var = self.encode(x)
79
+
80
+ z = self.sample_from_standard_normal(mu_z, log_var)
81
+ recon = self.decode(z)
82
+ return recon
83
+
84
+ class SimVPV2_Model(nn.Module):
85
+ def __init__(self, shape_in, shape_out, hid_S=16, hid_T=256, N_S=4, N_T=4,
86
+ mlp_ratio=8., drop=0.0, drop_path=0.0, spatio_kernel_enc=3,
87
+ spatio_kernel_dec=3, last_activation='none', act_inplace=True, **kwargs):
88
+ super(SimVPV2_Model, self).__init__()
89
+ T, C, H, W = shape_in # T is pre_seq_length
90
+ T2, C2, H2, W2 = shape_out # T2 is output length
91
+ assert C==C2 and H==H2 and W==W2, 'Need to be the same image shape for input and output'
92
+ self.T2 = T2
93
+ self.T = T
94
+
95
+ H, W = int(H / 2**(N_S/2)), int(W / 2**(N_S/2)) # downsample 1 / 2**(N_S/2)
96
+
97
+ self.vae = VAE(C_in=C, hid_S=hid_S, N_S=N_S, last_activation=last_activation)
98
+ self.hid = MidMetaNet(T*hid_S, T2*hid_S*2, hid_T, N_T,
99
+ input_resolution=(H, W), model_type='gsta',
100
+ mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
101
+
102
+ def forward(self, x_raw):
103
+ B, T, C, H, W = x_raw.shape
104
+ x = x_raw.reshape(B*T, C, H, W)
105
+
106
+ embed, log_var = self.vae.encode(x)
107
+ embed = self.vae.sample_from_standard_normal(embed, log_var)
108
+ *_, C_, H_, W_ = embed.shape
109
+ z = embed.view(B, T, C_, H_, W_)
110
+
111
+ hid, *_ = self.hid(z)
112
+ hid_mu, log_var_hid = torch.chunk(hid, 2, dim=1)
113
+ hid = self.vae.sample_from_standard_normal(hid_mu, log_var_hid)
114
+
115
+ hid = hid.reshape(B*self.T2, C_, H_, W_)
116
+ # conds_ = hid
117
+ conds_ = hid_mu.reshape(B*self.T2, C_, H_, W_)
118
+
119
+ Y = self.vae.decode(hid)
120
+ Y = Y.reshape(B, self.T2, C, H, W)
121
+ return Y, conds_
122
+
123
+ def _losses_(self, x, y):
124
+ y_pred, *_ = self.forward(x)
125
+ recon_loss = nn.MSELoss()(y_pred, y)
126
+ return recon_loss
stldm/simvpv2.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch, math
3
+
4
+ # from torchmodels.simvp import ConvSC, stride_generator
5
+
6
+ class SimVPV2_Model(nn.Module):
7
+ r"""SimVP Model
8
+
9
+ Implementation of `SimVP: Simpler yet Better Video Prediction
10
+ Just Remove The Skip Connection
11
+ <https://arxiv.org/abs/2206.05099>`_.
12
+
13
+ """
14
+ def __init__(self, shape_in, shape_out, hid_S=16, hid_T=256, N_S=4, N_T=4,
15
+ mlp_ratio=8., drop=0.0, drop_path=0.0, spatio_kernel_enc=3,
16
+ spatio_kernel_dec=3, last_activation='none', act_inplace=True, recursive=False, **kwargs):
17
+ super(SimVPV2_Model, self).__init__()
18
+ T, C, H, W = shape_in # T is pre_seq_length
19
+ T2, C2, H2, W2 = shape_out # T2 is output length
20
+ assert C==C2 and H==H2 and W==W2, 'Need to be the same image shape for input and output'
21
+ self.T2 = T2
22
+ self.T = T
23
+
24
+ H, W = int(H / 2**(N_S/2)), int(W / 2**(N_S/2)) # downsample 1 / 2**(N_S/2)
25
+ act_inplace = False
26
+
27
+ self.enc = Encoder(C, hid_S, N_S)#, spatio_kernel_enc, act_inplace=act_inplace)
28
+ self.dec = Decoder(hid_S, C, N_S, last_activation)#, spatio_kernel_dec, act_inplace=act_inplace)
29
+
30
+ # Modify HERE
31
+ self.recursive = recursive
32
+ if not self.recursive:
33
+ self.hid = MidMetaNet(T*hid_S, T2*hid_S, hid_T, N_T,
34
+ input_resolution=(H, W), model_type='gsta',
35
+ mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
36
+ else:
37
+ self.hid = MidMetaNet(T*hid_S, T*hid_S, hid_T, N_T,
38
+ input_resolution=(H, W), model_type='gsta',
39
+ mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
40
+ self.last_activation = last_activation
41
+
42
+ def forward(self, x_raw, **kwargs):
43
+ B, T, C, H, W = x_raw.shape
44
+ # x = x_raw.view(B*T, C, H, W)
45
+ x = x_raw.reshape(B*T, C, H, W)
46
+
47
+ embed = self.enc(x)
48
+ _, C_, H_, W_ = embed.shape
49
+
50
+ z = embed.view(B, T, C_, H_, W_)
51
+
52
+ if not self.recursive:
53
+ hid, conds_ = self.hid(z)
54
+ else:
55
+ no = self.T2//self.T
56
+ if self.T2%self.T != 0:
57
+ no += 1
58
+ hid = []
59
+ for i in range(no):
60
+ z, _ = self.hid(z)
61
+ hid.append(z)
62
+ hid = torch.cat(hid, dim=1)
63
+ hid = hid[:, :self.T2]
64
+ conds_ = hid.reshape(-1, C_, H_, W_)
65
+ # print(hid.shape, conds_.shape)
66
+
67
+ hid = hid.reshape(B*self.T2, C_, H_, W_)
68
+
69
+ Y = self.dec(hid)
70
+ Y = Y.reshape(B, self.T2, C, H, W)
71
+ return Y, conds_, hid.reshape(B, -1, C_, H_, W_)
72
+
73
+ def recon_loss(self, x, y):
74
+ X = torch.cat((x, y), dim=1)
75
+ B, T, C, H, W = X.shape
76
+ X = X.reshape(-1, C, H, W)
77
+ recon = self.dec(self.enc(X))
78
+ return nn.MSELoss()(recon, X)
79
+
80
+
81
+ class MidMetaNet(nn.Module):
82
+ """The hidden Translator of MetaFormer for SimVP"""
83
+ # Modify HERE with an additional param: channel_out
84
+ def __init__(self, channel_in, channel_out, channel_hid, N2,
85
+ input_resolution=None, model_type=None,
86
+ mlp_ratio=4., drop=0.0, drop_path=0.1):
87
+ super(MidMetaNet, self).__init__()
88
+ assert N2 >= 2 and mlp_ratio > 1
89
+ self.N2 = N2
90
+ dpr = [ # stochastic depth decay rule
91
+ x.item() for x in torch.linspace(1e-2, drop_path, self.N2)]
92
+
93
+ # downsample
94
+ enc_layers = [MetaBlock(
95
+ channel_in, channel_hid, input_resolution, model_type,
96
+ mlp_ratio, drop, drop_path=dpr[0], layer_i=0)]
97
+ # middle layers
98
+ for i in range(1, N2-1):
99
+ enc_layers.append(MetaBlock(
100
+ channel_hid, channel_hid, input_resolution, model_type,
101
+ mlp_ratio, drop, drop_path=dpr[i], layer_i=i))
102
+
103
+ # upsample
104
+ # Modify HERE
105
+ enc_layers.append(MetaBlock(
106
+ channel_hid, channel_out, input_resolution, model_type,
107
+ mlp_ratio, drop, drop_path=drop_path, layer_i=N2-1))
108
+ self.enc = nn.Sequential(*enc_layers)
109
+
110
+ def forward(self, x):
111
+ B, T, C, H, W = x.shape
112
+ x = x.reshape(B, T*C, H, W)
113
+
114
+ z = x
115
+ conds = [z]
116
+ for i in range(self.N2):
117
+ z = self.enc[i](z)
118
+ conds.append(z)
119
+
120
+ y = z.reshape(B, -1, C, H, W)
121
+ return y, y.reshape(-1, C, H, W) #conds #conds[:-1]
122
+
123
+ class MetaBlock(nn.Module):
124
+ """The hidden Translator of MetaFormer for SimVP"""
125
+
126
+ def __init__(self, in_channels, out_channels, input_resolution=None, model_type=None,
127
+ mlp_ratio=8., drop=0.0, drop_path=0.0, layer_i=0):
128
+ super(MetaBlock, self).__init__()
129
+ self.in_channels = in_channels
130
+ self.out_channels = out_channels
131
+ model_type = model_type.lower() if model_type is not None else 'gsta'
132
+
133
+ if model_type == 'gsta':
134
+ self.block = GASubBlock(
135
+ in_channels, kernel_size=21, mlp_ratio=mlp_ratio,
136
+ drop=drop, drop_path=drop_path, act_layer=nn.GELU)
137
+ else:
138
+ assert False and "Invalid model_type in SimVP"
139
+
140
+ if in_channels != out_channels:
141
+ self.reduction = nn.Conv2d(
142
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0)
143
+
144
+ def forward(self, x):
145
+ z = self.block(x)
146
+ return z if self.in_channels == self.out_channels else self.reduction(z)
147
+
148
+ class GASubBlock(nn.Module):
149
+ """A GABlock (gSTA) for SimVP"""
150
+
151
+ def __init__(self, dim, kernel_size=21, mlp_ratio=4.,
152
+ drop=0., drop_path=0.1, init_value=1e-2, act_layer=nn.GELU):
153
+ super().__init__()
154
+ self.norm1 = nn.BatchNorm2d(dim)
155
+ self.attn = SpatialAttention(dim, kernel_size)
156
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
157
+
158
+ self.norm2 = nn.BatchNorm2d(dim)
159
+ mlp_hidden_dim = int(dim * mlp_ratio)
160
+ self.mlp = MixMlp(
161
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
162
+
163
+ self.layer_scale_1 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)
164
+ self.layer_scale_2 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)
165
+
166
+ self.apply(self._init_weights)
167
+
168
+ def _init_weights(self, m):
169
+ if isinstance(m, nn.Linear):
170
+ trunc_normal_(m.weight, std=.02)
171
+ if isinstance(m, nn.Linear) and m.bias is not None:
172
+ nn.init.constant_(m.bias, 0)
173
+ elif isinstance(m, nn.LayerNorm):
174
+ nn.init.constant_(m.bias, 0)
175
+ nn.init.constant_(m.weight, 1.0)
176
+ elif isinstance(m, nn.Conv2d):
177
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
178
+ fan_out //= m.groups
179
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
180
+ if m.bias is not None:
181
+ m.bias.data.zero_()
182
+
183
+ @torch.jit.ignore
184
+ def no_weight_decay(self):
185
+ return {'layer_scale_1', 'layer_scale_2'}
186
+
187
+ def forward(self, x):
188
+ x = x + self.drop_path(
189
+ self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x)))
190
+ x = x + self.drop_path(
191
+ self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
192
+ return x
193
+
194
+ class SpatialAttention(nn.Module):
195
+ """A Spatial Attention block for SimVP"""
196
+
197
+ def __init__(self, d_model, kernel_size=21, attn_shortcut=True):
198
+ super().__init__()
199
+
200
+ self.proj_1 = nn.Conv2d(d_model, d_model, 1) # 1x1 conv
201
+ self.activation = nn.GELU() # GELU
202
+ self.spatial_gating_unit = AttentionModule(d_model, kernel_size)
203
+ self.proj_2 = nn.Conv2d(d_model, d_model, 1) # 1x1 conv
204
+ self.attn_shortcut = attn_shortcut
205
+
206
+ def forward(self, x):
207
+ if self.attn_shortcut:
208
+ shortcut = x.clone()
209
+ x = self.proj_1(x)
210
+ x = self.activation(x)
211
+ x = self.spatial_gating_unit(x)
212
+ x = self.proj_2(x)
213
+ if self.attn_shortcut:
214
+ x = x + shortcut
215
+ return x
216
+
217
+ class AttentionModule(nn.Module):
218
+ """Large Kernel Attention for SimVP"""
219
+
220
+ def __init__(self, dim, kernel_size, dilation=3):
221
+ super().__init__()
222
+ d_k = 2 * dilation - 1
223
+ d_p = (d_k - 1) // 2
224
+ dd_k = kernel_size // dilation + ((kernel_size // dilation) % 2 - 1)
225
+ dd_p = (dilation * (dd_k - 1) // 2)
226
+
227
+ self.conv0 = nn.Conv2d(dim, dim, d_k, padding=d_p, groups=dim)
228
+ self.conv_spatial = nn.Conv2d(
229
+ dim, dim, dd_k, stride=1, padding=dd_p, groups=dim, dilation=dilation)
230
+ self.conv1 = nn.Conv2d(dim, 2*dim, 1)
231
+
232
+ def forward(self, x):
233
+ u = x.clone()
234
+ attn = self.conv0(x) # depth-wise conv
235
+ attn = self.conv_spatial(attn) # depth-wise dilation convolution
236
+
237
+ f_g = self.conv1(attn)
238
+ split_dim = f_g.shape[1] // 2
239
+ f_x, g_x = torch.split(f_g, split_dim, dim=1)
240
+ return torch.sigmoid(g_x) * f_x
241
+
242
+ class DWConv(nn.Module):
243
+ def __init__(self, dim=768):
244
+ super(DWConv, self).__init__()
245
+ self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
246
+
247
+ def forward(self, x):
248
+ x = self.dwconv(x)
249
+ return x
250
+
251
+
252
+ class MixMlp(nn.Module):
253
+ def __init__(self,
254
+ in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
255
+ super().__init__()
256
+ out_features = out_features or in_features
257
+ hidden_features = hidden_features or in_features
258
+ self.fc1 = nn.Conv2d(in_features, hidden_features, 1) # 1x1
259
+ self.dwconv = DWConv(hidden_features) # CFF: Convlutional feed-forward network
260
+ self.act = act_layer() # GELU
261
+ self.fc2 = nn.Conv2d(hidden_features, out_features, 1) # 1x1
262
+ self.drop = nn.Dropout(drop)
263
+ self.apply(self._init_weights)
264
+
265
+ def _init_weights(self, m):
266
+ if isinstance(m, nn.Linear):
267
+ trunc_normal_(m.weight, std=.02)
268
+ if isinstance(m, nn.Linear) and m.bias is not None:
269
+ nn.init.constant_(m.bias, 0)
270
+ elif isinstance(m, nn.LayerNorm):
271
+ nn.init.constant_(m.bias, 0)
272
+ nn.init.constant_(m.weight, 1.0)
273
+ elif isinstance(m, nn.Conv2d):
274
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
275
+ fan_out //= m.groups
276
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
277
+ if m.bias is not None:
278
+ m.bias.data.zero_()
279
+
280
+ def forward(self, x):
281
+ x = self.fc1(x)
282
+ x = self.dwconv(x)
283
+ x = self.act(x)
284
+ x = self.drop(x)
285
+ x = self.fc2(x)
286
+ x = self.drop(x)
287
+ return x
288
+
289
+
290
+ """
291
+ From TIMM repo: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
292
+ """
293
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Stochastic depth: randomly zero entire samples of a residual branch.

    Adapted from the TIMM implementation. A per-sample Bernoulli mask of
    shape (B, 1, ..., 1) is drawn with keep probability 1 - drop_prob and
    multiplied into `x`; kept samples are optionally rescaled by
    1 / keep_prob so the expected value is unchanged.

    No-op when drop_prob == 0 or when not in training mode.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # One mask value per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        mask.div_(keep_prob)
    return x * mask
311
+
312
+
313
class DropPath(nn.Module):
    """Stochastic-depth module: randomly skips this branch per sample.

    Thin nn.Module wrapper around the functional `drop_path`; only active
    when the module is in training mode.
    """
    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        # drop_prob: probability of zeroing a sample's branch output.
        # scale_by_keep: rescale kept samples by 1/keep_prob.
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # self.training toggles the drop; eval mode is always identity.
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        # Shown in the module repr, e.g. "DropPath(drop_prob=0.100)".
        return f'drop_prob={round(self.drop_prob,3):0.3f}'
326
+
327
+ def _trunc_normal_(tensor, mean, std, a, b):
328
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
329
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
330
+ def norm_cdf(x):
331
+ # Computes standard normal cumulative distribution function
332
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
333
+
334
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
335
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
336
+ "The distribution of values may be incorrect.",
337
+ stacklevel=2)
338
+
339
+ # Values are generated by using a truncated uniform distribution and
340
+ # then using the inverse CDF for the normal distribution.
341
+ # Get upper and lower cdf values
342
+ l = norm_cdf((a - mean) / std)
343
+ u = norm_cdf((b - mean) / std)
344
+
345
+ # Uniformly fill tensor with values from [l, u], then translate to
346
+ # [2l-1, 2u-1].
347
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
348
+
349
+ # Use inverse cdf transform for normal distribution to get truncated
350
+ # standard normal
351
+ tensor.erfinv_()
352
+
353
+ # Transform to proper mean, std
354
+ tensor.mul_(std * math.sqrt(2.))
355
+ tensor.add_(mean)
356
+
357
+ # Clamp to ensure it's in the proper range
358
+ tensor.clamp_(min=a, max=b)
359
+ return tensor
360
+
361
class Encoder(nn.Module):
    """Spatial encoder: a stack of N_S ConvSC layers.

    Strides come from stride_generator(N_S) (alternating 1, 2); the first
    layer maps C_in -> C_hid, the rest keep C_hid channels.
    """
    def __init__(self, C_in, C_hid, N_S):
        super(Encoder, self).__init__()
        strides = stride_generator(N_S)
        layers = [ConvSC(C_in, C_hid, stride=strides[0])]
        layers.extend(ConvSC(C_hid, C_hid, stride=s) for s in strides[1:])
        self.enc = nn.Sequential(*layers)

    def forward(self, x):  # x: (B*T, C_in, H, W)
        # Run every layer explicitly; equivalent to self.enc(x).
        latent = x
        for layer in self.enc:
            latent = layer(latent)
        return latent
376
+
377
class Decoder(nn.Module):
    """Spatial decoder mirroring Encoder: N_S transposed ConvSC layers,
    a 1x1 readout conv to C_out channels, then Sigmoid (or Identity).
    """
    def __init__(self, C_hid, C_out, N_S, last_activation='sigmoid'):
        super(Decoder, self).__init__()
        # Reversed stride schedule undoes the encoder's downsampling.
        strides = stride_generator(N_S, reverse=True)
        self.dec = nn.Sequential(
            *(ConvSC(C_hid, C_hid, stride=s, transpose=True) for s in strides)
        )
        self.readout = nn.Conv2d(C_hid, C_out, 1)
        self.last = nn.Sigmoid() if last_activation == 'sigmoid' else nn.Identity()

    def forward(self, hid):
        out = hid
        for layer in self.dec:
            out = layer(out)
        return self.last(self.readout(out))
397
+
398
class BasicConv2d(nn.Module):
    """Conv2d or ConvTranspose2d, optionally followed by GroupNorm + LeakyReLU.

    transpose=True selects ConvTranspose2d with output_padding=stride//2 so
    a stride-2 layer exactly doubles the spatial size.
    act_norm=True applies GroupNorm(2 groups) then LeakyReLU(0.2) after the conv.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, transpose=False, act_norm=False):
        super(BasicConv2d, self).__init__()
        self.act_norm = act_norm
        if transpose:
            self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
                                           stride=stride, padding=padding, output_padding=stride // 2)
        else:
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                                  stride=stride, padding=padding)
        # Built unconditionally (matching original), applied only when act_norm.
        self.norm = nn.GroupNorm(2, out_channels)
        self.act = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        out = self.conv(x)
        return self.act(self.norm(out)) if self.act_norm else out
414
+
415
+
416
class ConvSC(nn.Module):
    """3x3 BasicConv2d wrapper (padding 1, act+norm on by default).

    A stride-1 layer never needs a transposed conv, so transpose is
    forced off in that case.
    """
    def __init__(self, C_in, C_out, stride, transpose=False, act_norm=True):
        super(ConvSC, self).__init__()
        transpose = transpose and stride != 1
        self.conv = BasicConv2d(C_in, C_out, kernel_size=3, stride=stride,
                                padding=1, transpose=transpose, act_norm=act_norm)

    def forward(self, x):
        return self.conv(x)
427
+
428
def stride_generator(N, reverse=False):
    """First N entries of the alternating stride pattern [1, 2, 1, 2, ...].

    reverse=True returns the same prefix in reversed order (used by the
    decoder to mirror the encoder). Supports N up to 20.
    """
    pattern = ([1, 2] * 10)[:N]
    return pattern[::-1] if reverse else pattern
stldm/stldm.py ADDED
@@ -0,0 +1,612 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, random
2
+ from torch import nn
3
+ from einops import rearrange
4
+
5
+ from stldm.submodules import *
6
+
7
class Down_Block(nn.Module):
    """Downsampling stage of the latent-diffusion U-Net.

    Pipeline per call: concat(x, cond) -> ResNet block (time-conditioned)
    -> spatial attention -> ResNet block -> temporal attention ->
    downsample (or channel conversion on the last stage).

    Returns (downsampled output, spatial-attention map, temporal-attention
    map); the two attention maps are consumed as skip connections by
    Up_Block.
    """
    def __init__(self, in_ch, hid_ch, out_ch, time_dim, is_last, patch_size=None, num_groups=8, heads=4, dim_head=32):
        super(Down_Block, self).__init__()
        self.block1 = ResnetBlock(dim=in_ch, dim_out=hid_ch, time_emb_dim=time_dim, groups=num_groups)
        # Quadratic attention when no patch size is given, linear (patched) otherwise.
        self.attn_spatial = Residual(PreNorm(hid_ch, Quadratic_SpatialAttention(dim=hid_ch, heads=heads, dim_head=dim_head))) if patch_size is None else Residual(PreNorm(hid_ch, Linear_SpatialAttention(dim=hid_ch, patch_size=patch_size, heads=heads, dim_head=dim_head)))
        self.block2 = ResnetBlock(dim=hid_ch, dim_out=hid_ch, groups=num_groups)
        # self.attn_temporal = Residual(PreNorm(hid_ch, TemporalAttention_Pos(dim=hid_ch, heads=heads, dim_head=dim_head)))
        self.attn_temporal = Residual(PreNorm(hid_ch, TemporalAttention(dim=hid_ch, heads=heads, dim_head=dim_head)))
        self.last = Downsample2D(dim_in=hid_ch, dim_out=out_ch) if not is_last else ChannelConversion(hid_ch, out_ch)

    def forward(self, x, time_emb, cond=None, relative_pos=None):
        # x: (B, T, C, H, W). relative_pos is currently unused — the
        # TemporalAttention_Pos variant that consumed it is commented out.
        assert x.ndim==5
        B, T, C, H, W = x.shape

        # Fold time into batch for the 2D spatial modules.
        x = x.reshape(B*T, C, H, W)
        if cond is None:
            cond = torch.zeros_like(x) # -> Unconditioning

        # Broadcast the per-batch time embedding (B, D) to every frame: (B*T, D).
        time_emb = time_emb.unsqueeze(1) # From (B C) to (B 1 C)
        time_emb = time_emb.repeat(1, T, 1)
        time_emb = time_emb.reshape(B*T, -1)

        # Channel-concat conditioning: block1's in_ch must equal 2x the latent C.
        out = torch.cat((x, cond), dim=1) # BT, 2C, H, W
        out = self.block1(out, time_emb)

        spatial_attn = self.attn_spatial(out)
        out = self.block2(spatial_attn, time_emb)
        *_, c, h, w = out.shape
        # Unfold back to (B, T, ...) so temporal attention sees the time axis.
        out = out.reshape(B,T,c,h,w)

        # temporal_attn = self.attn_temporal(out, relative_pos)
        temporal_attn = self.attn_temporal(out)
        temporal_attn = temporal_attn.reshape(B*T,c,h,w)

        out = self.last(temporal_attn)
        *_, c, h, w = out.shape

        return out.reshape(B, T, c, h, w), spatial_attn, temporal_attn
45
+
46
class MidBlock(nn.Module):
    """Bottleneck of the latent-diffusion U-Net.

    ResNet -> quadratic spatial attention -> ResNet -> temporal attention
    -> ResNet, all time-conditioned. Channel count and resolution are
    preserved (in_ch -> in_ch).
    """
    def __init__(self, in_ch, time_dim, num_groups=8, heads=4, dim_head=32):
        super(MidBlock, self).__init__()
        self.block1 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)
        self.qattn_spatial = Residual(PreNorm(in_ch, Quadratic_SpatialAttention(dim=in_ch, heads=heads, dim_head=dim_head)))
        self.block2 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)
        # self.qattn_time = Residual(PreNorm(in_ch, TemporalAttention_Pos(dim=in_ch, heads=heads, dim_head=dim_head)))
        self.qattn_time = Residual(PreNorm(in_ch, TemporalAttention(dim=in_ch, heads=heads, dim_head=dim_head)))
        self.block3 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)

    def forward(self, x, time_emb, relative_pos=None):
        # x: (B, T, C, H, W). relative_pos is currently unused (see the
        # commented-out TemporalAttention_Pos call below).
        assert x.ndim==5
        B, T, C, H, W = x.shape
        # Fold time into batch for the 2D modules.
        x = x.reshape(B*T, C, H, W)

        # Broadcast the per-batch time embedding (B, D) to every frame: (B*T, D).
        time_emb = time_emb.unsqueeze(1) # From (B C) to (B 1 C)
        time_emb = time_emb.repeat(1, T, 1)
        time_emb = time_emb.reshape(B*T, -1)

        out = self.block1(x, time_emb)
        out = self.qattn_spatial(out)
        out = self.block2(out, time_emb) # a little bit difference here

        # Re-expose the time axis for temporal attention, then fold back.
        out = out.reshape((B, T, C, H, W))
        # out = self.qattn_time(out, relative_pos).reshape(B*T, C, H, W)
        out = self.qattn_time(out).reshape(B*T, C, H, W)
        out = self.block3(out, time_emb)

        *_, c, h, w = out.shape
        return out.reshape(B, T, c, h, w)
76
+
77
class Up_Block(nn.Module):
    """Upsampling stage of the latent-diffusion U-Net.

    Pipeline: upsample (or channel conversion on the outermost stage)
    -> temporal attention -> concat temporal skip -> ResNet block
    -> spatial attention -> concat spatial skip -> ResNet block.

    The two skips are the attention maps returned by the matching Down_Block.
    """
    def __init__(self, in_chs, hid_ch, out_ch, is_last, time_dim, patch_size=None, num_groups=8, heads=4, dim_head=32):
        super(Up_Block, self).__init__()
        # in_chs = (channels coming up the decoder, channels of the skip maps).
        in_ch, skip_ch = in_chs
        self.up = Upsample2D(dim_in=in_ch, dim_out=hid_ch) if not is_last else ChannelConversion(in_ch, hid_ch)
        self.attn_spatial = Residual(PreNorm(hid_ch, Quadratic_SpatialAttention(dim=hid_ch, heads=heads, dim_head=dim_head) if patch_size is None else Linear_SpatialAttention(dim=hid_ch, patch_size=patch_size, heads=heads, dim_head=dim_head)))
        self.block1 = ResnetBlock(dim=hid_ch+skip_ch, dim_out=hid_ch, time_emb_dim=time_dim, groups=num_groups)
        # self.attn_temporal = Residual(PreNorm(hid_ch, TemporalAttention_Pos(dim=hid_ch, heads=heads, dim_head=dim_head)))
        self.attn_temporal = Residual(PreNorm(hid_ch, TemporalAttention(dim=hid_ch, heads=heads, dim_head=dim_head)))
        self.block2 = ResnetBlock(dim=hid_ch+skip_ch, dim_out=out_ch, time_emb_dim=time_dim, groups=num_groups)

    def forward(self, x, time_emb, spatialattn_skip, tempattn_skip, relative_pos=None):
        # x: (B, T, C, H, W); skips are (B*T, skip_ch, h, w) from Down_Block.
        # relative_pos is currently unused (TemporalAttention_Pos commented out).
        assert x.ndim==5
        B, T, C, H, W = x.shape
        x = x.reshape(B*T, C, H, W)

        # Broadcast the per-batch time embedding (B, D) to every frame: (B*T, D).
        time_emb = time_emb.unsqueeze(1) # From (B C) to (B 1 C)
        time_emb = time_emb.repeat(1, T, 1)
        time_emb = time_emb.reshape(B*T, -1)

        out = self.up(x)
        *_, c, h, w = out.shape
        out = out.reshape(-1, T, c, h, w)

        # out = self.attn_temporal(out, relative_pos).reshape(B*T, c, h, w)
        out = self.attn_temporal(out).reshape(B*T, c, h, w)

        # Temporal skip first (feeds block1), then spatial skip (feeds block2).
        out = torch.cat((out, tempattn_skip), dim=1)
        out = self.block1(out, time_emb)

        out = self.attn_spatial(out)

        out = torch.cat((out, spatialattn_skip), dim=1)
        out = self.block2(out, time_emb)
        *_, c, h, w = out.shape
        return out.reshape(B, T, c, h, w)
113
+
114
class LDM(nn.Module):
    """Spatio-temporal U-Net used as the latent-diffusion denoiser.

    Built from len(chs_mult) Down_Blocks, a MidBlock, and matching
    Up_Blocks; a parallel `conditions` stack of Downsample2D layers keeps
    the conditioning tensor at the resolution each Down_Block expects.

    forward(x, time, conds) expects x: (B, T, C, H, W) latents, `time` the
    diffusion timestep batch, `conds` an optional conditioning latent
    (None -> unconditional); returns a tensor of the same shape family.
    """
    def __init__(self, in_ch, chs_mult:tuple, patch_size=None, num_groups=8, heads=4, dim_head=32, base_ch=64):
        super(LDM, self).__init__()
        # Time Embedding MLP: sinusoidal fourier features -> MLP of width 4*base_ch.
        time_dim = 4*base_ch
        fourier_dim = base_ch
        self.time_mlp = Time_MLP(dim=base_ch, time_dim=time_dim, fourier_dim=fourier_dim)

        ups, downs = [], []
        conditions = []

        layer_no = len(chs_mult)
        # Channel schedule, e.g. in_ch=4, base_ch=64, chs_mult=(1,2) -> [4, 64, 128].
        chs = [in_ch, *map(lambda m: base_ch*m, chs_mult)]
        ch_in, ch_out = chs[:-1], chs[1:]
        up_in, up_out = list(reversed(ch_out)), list(reversed(ch_in))

        patches = None if patch_size is None else [patch_size//(2**n) for n in range(layer_no)] # Patch Size should be 2^N
        for n in range(layer_no):
            # Down_Block in_ch is doubled: it concatenates x with cond on channels.
            downs.append(
                Down_Block(in_ch=2*ch_in[n], hid_ch=ch_in[n], out_ch=ch_out[n], time_dim=time_dim, patch_size=None if patch_size is None else patches[n], is_last=(n==layer_no-1), num_groups=num_groups, heads=heads, dim_head=dim_head)
            )
            ups.append(
                Up_Block(in_chs=(up_in[n], ch_in[-n-1]), hid_ch=up_in[n], out_ch=up_out[n], time_dim=time_dim, patch_size=None if patch_size is None else patches[layer_no-n-1], is_last=(n==0), num_groups=num_groups, heads=heads, dim_head=dim_head)
            )
            # NOTE(review): `n != -1` is always true, so a condition
            # downsampler is appended for every level — possibly a leftover
            # from an `n != layer_no-1` guard; confirm intent.
            if n != -1:
                conditions.append(
                    Downsample2D(dim_in=ch_in[n], dim_out=ch_out[n])
                )

        self.downs = nn.ModuleList(downs)
        self.ups = nn.ModuleList(ups)
        self.conditions = nn.ModuleList(conditions)
        self.mid = MidBlock(in_ch=ch_out[-1], time_dim=time_dim, num_groups=num_groups, heads=heads, dim_head=dim_head)
        # self.relative_pos = RelativePositionBias(heads=heads)

    def forward(self, x, time, conds=None):
        # Embed the diffusion timestep once; blocks broadcast it per frame.
        t = self.time_mlp(time)

        # Skip stacks: spatial/temporal attention maps from each Down_Block.
        hid_spatial = []
        hid_temporal = []

        # relative_position = self.relative_pos(x.shape[1], x.device) # Calculate The Relative Position

        for n, down_block in enumerate(self.downs):
            # x, spatial_attn, time_attn = down_block(x, t, conds, relative_position)
            x, spatial_attn, time_attn = down_block(x, t, conds)
            hid_spatial.append(spatial_attn)
            hid_temporal.append(time_attn)
            if conds is not None:
                # Keep the conditioning tensor at the next level's resolution.
                conds = self.conditions[n](conds)

        # out = self.mid(x, t, relative_position)
        out = self.mid(x, t)

        for up_block in self.ups:
            # Pop in LIFO order so each Up_Block gets its mirror's skips.
            # out = up_block(out, t, hid_spatial.pop(), hid_temporal.pop(), relative_position)
            out = up_block(out, t, hid_spatial.pop(), hid_temporal.pop())

        return out
174
+
175
+ # constants
176
+ from collections import namedtuple
177
+ from torch.cuda.amp import autocast
178
+ import torch.nn.functional as F
179
+ from einops import reduce
180
+ from tqdm.auto import tqdm
181
+
182
+ ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
183
+
184
def identity(t, *args, **kwargs):
    """Return `t` unchanged; any extra arguments are accepted and ignored."""
    return t
186
+
187
def extract(a, t, x_shape):
    """Gather per-timestep coefficients a[t] and reshape for broadcasting.

    a: 1-D schedule tensor; t: (B,) integer timestep indices; x_shape: the
    shape of the tensor the coefficients will multiply. Returns a tensor of
    shape (B, 1, ..., 1) with len(x_shape) - 1 trailing singleton dims.
    """
    batch = t.shape[0]
    coefs = a.gather(-1, t)
    return coefs.reshape(batch, *([1] * (len(x_shape) - 1)))
191
+
192
def default(val, d):
    """Return `val` when it is not None, else the fallback `d`.

    A callable fallback is invoked lazily, so expensive defaults are only
    computed when needed.
    """
    if val is not None:
        return val
    return d() if callable(d) else d
196
+
197
def exists(x):
    """True iff `x` is not None (0, '', and [] all count as existing)."""
    return x is not None
199
+
200
def guidance_scheduler(sampling_step: int, const: float):
    """Constant classifier-free-guidance schedule.

    Returns a length-`sampling_step` float tensor filled with `const`,
    indexed by timestep during sampling.
    """
    return torch.ones(sampling_step) * const
202
+
203
class GaussianDiffusion(nn.Module):
    """Latent Gaussian diffusion over video frames (DDPM/DDIM style).

    Wraps a video-prediction backbone (`vp_model`, providing a VAE and a
    conditioning path) and a denoising U-Net (`diffusion`). Noise schedules
    and derived coefficients are precomputed as float32 buffers in __init__.
    Closely follows lucidrains' denoising-diffusion-pytorch.
    """
    def __init__(
        self,
        vp_model,
        diffusion,
        timesteps = 1000,
        sampling_timesteps = None,
        objective = 'pred_v',
        beta_schedule = 'sigmoid',
        # NOTE(review): mutable default argument; harmless here because it
        # is only **-unpacked, never mutated, but `None` would be safer.
        schedule_fn_kwargs = dict(),
        ddim_sampling_eta = 0.,
        offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise
        min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
        min_snr_gamma = 5
    ):
        super(GaussianDiffusion, self).__init__()

        self.backbone = vp_model
        self.diff_unet = diffusion

        self.objective = objective
        assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'

        # Schedule functions are expected from `stldm.submodules` (star import).
        if beta_schedule == 'linear':
            beta_schedule_fn = linear_beta_schedule
        elif beta_schedule == 'cosine':
            beta_schedule_fn = cosine_beta_schedule
        elif beta_schedule == 'sigmoid':
            beta_schedule_fn = sigmoid_beta_schedule
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')

        betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)

        # sampling related parameters

        self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training

        assert self.sampling_timesteps <= timesteps
        # Fewer sampling steps than training steps -> DDIM sampling.
        self.is_ddim_sampling = self.sampling_timesteps < timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

        # helper function to register buffer from float64 to float32

        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others

        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)

        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)

        register_buffer('posterior_variance', posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain

        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        # offset noise strength - in blogpost, they claimed 0.1 was ideal

        self.offset_noise_strength = offset_noise_strength

        # derive loss weight
        # snr - signal noise ratio

        snr = alphas_cumprod / (1 - alphas_cumprod)

        # https://arxiv.org/abs/2303.09556

        maybe_clipped_snr = snr.clone()
        if min_snr_loss_weight:
            maybe_clipped_snr.clamp_(max = min_snr_gamma)

        if objective == 'pred_noise':
            register_buffer('loss_weight', maybe_clipped_snr / snr)
        elif objective == 'pred_x0':
            register_buffer('loss_weight', maybe_clipped_snr)
        elif objective == 'pred_v':
            register_buffer('loss_weight', maybe_clipped_snr / (snr + 1))

    @property
    def device(self):
        # Device of the schedule buffers (and hence the module).
        return self.betas.device

    # CFG scheduler => installs a pre-computed guidance-strength schedule.
    # NOTE(review): model_predictions reads self.CFG_sch unconditionally, so
    # setup_guidance (even with None) must be called before sampling.
    def setup_guidance(self, scheduler):
        if exists(scheduler):
            self.CFG_sch = scheduler.to(self.device)
        else:
            self.CFG_sch = scheduler

    def predict_start_from_noise(self, x_t, t, noise):
        """x_0 from (x_t, predicted noise)."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        """Noise from (x_t, x_0) — inverse of predict_start_from_noise."""
        return (
            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    def predict_v(self, x_start, t, noise):
        """v-parameterization target (progressive distillation, appendix D)."""
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
        )

    def predict_start_from_v(self, x_t, t, v):
        """x_0 from (x_t, predicted v)."""
        return (
            extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def q_posterior(self, x_start, x_t, t):
        """Mean/variance/log-variance of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def model_predictions(self, x, t, cond, clip_x_start = False, rederive_pred_noise = False):
        """Run the denoiser (with optional classifier-free guidance) and
        return ModelPrediction(pred_noise, pred_x_start) for any objective.
        """
        if exists(self.CFG_sch):
            # Classifier-free guidance: mix conditional and unconditional
            # outputs with the per-timestep strength from CFG_sch.
            uncond = self.diff_unet(x, t, conds=None) #conds=torch.zeros_like(cond))
            model_output = self.diff_unet(x, t, conds=cond)
            time = int(t[0]) # assumes all batch elements share one timestep
            model_output = model_output - self.CFG_sch[time] * (uncond - model_output)
        else:
            model_output = self.diff_unet(x, t, conds=cond)
        # NOTE(review): `partial` is expected from the submodules star import.
        maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity

        if self.objective == 'pred_noise':
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, pred_noise)
            x_start = maybe_clip(x_start)

            if clip_x_start and rederive_pred_noise:
                pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_x0':
            x_start = model_output
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_v':
            v = model_output
            x_start = self.predict_start_from_v(x, t, v)
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        return ModelPrediction(pred_noise, x_start)

    def p_mean_variance(self, x, t, cond=None, clip_denoised = True):
        """Posterior moments of p(x_{t-1} | x_t) given the model's x_0 estimate."""
        preds = self.model_predictions(x, t, cond=cond, clip_x_start=False,)
        x_start = preds.pred_x_start

        if clip_denoised:
            x_start.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.no_grad()
    def p_sample(self, x, t: int, cond=None):
        """One ancestral (DDPM) sampling step at scalar timestep t."""
        b, *_, device = *x.shape, self.device
        batched_times = torch.full((b,), t, device = device, dtype = torch.long)
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, cond=cond, clip_denoised = False)
        noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred_img, x_start

    @torch.no_grad()
    def p_sample_loop(self, shape, cond=None, return_all_timesteps = False):
        """Full DDPM sampling chain from pure noise down to t = 0."""
        batch, device = shape[0], self.device

        frames_pred = torch.randn(shape, device = device)
        imgs = [frames_pred]

        for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps, disable=True):
            frames_pred, _ = self.p_sample(frames_pred, t, cond=cond)
            imgs.append(frames_pred)

        ret = frames_pred if not return_all_timesteps else torch.stack(imgs, dim = 1)
        return ret

    @torch.no_grad()
    def ddim_sample(self, shape, cond=None, return_all_timesteps = False):
        """DDIM sampling over `sampling_timesteps` steps (eta controls stochasticity)."""
        batch, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
        device = self.device
        times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        frames_pred = torch.randn(shape, device = device)
        imgs = [frames_pred]

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', disable=True):
            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
            pred_noise, x_start, *_ = self.model_predictions(
                frames_pred,
                time_cond,
                cond = cond, #cond.copy(),
                clip_x_start = False,
                rederive_pred_noise = True
            )

            # Final step: the x_0 estimate is the sample.
            if time_next < 0:
                frames_pred = x_start
                imgs.append(frames_pred)
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            # Standard DDIM update coefficients (Song et al.).
            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()

            noise = torch.randn_like(frames_pred)

            frames_pred = x_start * alpha_next.sqrt() + \
                          c * pred_noise + \
                          sigma * noise

            imgs.append(frames_pred)

        ret = frames_pred if not return_all_timesteps else torch.stack(imgs, dim = 1)
        return ret

    @torch.no_grad()
    def sample(self, frames_in, return_all_timesteps = False):
        """Generate future frames conditioned on `frames_in` (B, T, C, H, W).

        Runs the backbone for conditioning, samples latents (DDPM or DDIM),
        then decodes them with the backbone's VAE. Returns
        (decoded frames, backbone output).
        """
        assert frames_in.ndim == 5
        B, T_in, C, H, W = frames_in.shape
        device = self.device

        backbone_output, conds, *_ = self.backbone(frames_in)
        sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample

        *_, c, h, w = conds.shape
        # Target latent shape mirrors the conditioning latents, batched by B.
        tgt_shape = conds.reshape(B, -1, c, h, w).shape
        ldm_pred = sample_fn(
            tgt_shape,
            cond=conds,
            return_all_timesteps = return_all_timesteps
        )

        ldm_pred = rearrange(ldm_pred, 'b t c h w -> (b t) c h w')
        frames_pred = self.backbone.vae.decode(ldm_pred)
        frames_pred = rearrange(frames_pred, '(b t) c h w -> b t c h w', b=B)
        return frames_pred, backbone_output

    def predict(self, frames_in, compute_loss=False, **kwargs):
        # Convenience alias: compute_loss/kwargs are currently ignored.
        pred, mu = self.sample(frames_in=frames_in)
        return pred, mu

    def compute_loss(self, frames_in, frames_gt, validate=False):
        """Training losses: returns (recon_loss, mu_loss, diff_loss, prior_loss)."""
        # NOTE(review): local `compute_loss` flag shadows the method name and
        # is never used below — confirm whether it was meant to gate anything.
        compute_loss = True and (not validate)
        B, T_in, C, H, W = frames_in.shape
        T_out = frames_gt.shape[1]
        device = frames_in.device

        """
        Diffusion Loss
        """
        backbone_output, conds = self.backbone(frames_in)
        hid_gt, _ = self.backbone.vae.encode(
            rearrange(frames_gt, 'b t c h w -> (b t) c h w')
        )
        hid_gt = rearrange(hid_gt, '(b t) c h w -> b t c h w', b=B)
        t = torch.randint(0, self.num_timesteps, (B,), device=self.device).long()
        # ~15% of batches are trained unconditionally (for CFG at sampling time).
        if random.random() > 0.85: # Unconditional
            conds = None
        diff_loss = self.p_losses(hid_gt.detach(), t, cond=conds)

        """
        Backbone Loss
        """
        mu_loss = self.backbone._losses_(frames_in, frames_gt)

        """
        VAE Loss
        """
        # Reconstruction over the concatenated input+target clip.
        ae_loss, kl_loss = self.backbone.vae._losses_(
            rearrange(torch.cat((frames_in, frames_gt), dim=1), 'b t c h w -> (b t) c h w'),
            rearrange(torch.cat((frames_in, frames_gt), dim=1), 'b t c h w -> (b t) c h w')
        )
        kl_weight = 1E-6
        recon_loss = ae_loss + kl_weight*kl_loss

        """
        Prior Loss at t=T [Noisy]
        """
        # Encourage the fully-noised latent distribution to match N(0, I).
        hid_gt, _ = self.backbone.vae.encode(
            rearrange(frames_gt, 'b t c h w -> (b t) c h w')
        )
        hid_gt = rearrange(hid_gt, '(b t) c h w -> b t c h w', b=B)
        T = torch.ones((B,), device=self.device).long() * (self.num_timesteps - 1)
        mu_noisy = extract(self.sqrt_alphas_cumprod, T, hid_gt.shape) * hid_gt
        sigma_noisy = extract(self.sqrt_one_minus_alphas_cumprod, T, hid_gt.shape)
        log_var_noisy = 2*torch.log(sigma_noisy)
        prior_loss = self.kl_from_standard_normal(mu_noisy, log_var_noisy)

        return recon_loss, mu_loss, diff_loss, prior_loss


    def kl_from_standard_normal(self, mean, log_var):
        """Mean KL divergence of N(mean, exp(log_var)) from N(0, I)."""
        kl = 0.5 * (log_var.exp() + mean.square() - 1.0 - log_var)
        return kl.mean()

    @autocast(enabled = False)
    def q_sample(self, x_start, t, noise = None):
        """Forward-diffuse x_start to timestep t (run in full precision)."""
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, t, noise=None, offset_noise_strength=None, cond=None):
        """SNR-weighted MSE between the denoiser output and its objective target."""
        b, T, c, h, w = x_start.shape

        noise = default(noise, lambda: torch.randn_like(x_start))

        # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
        offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)

        if offset_noise_strength > 0.:
            offset_noise = torch.randn(x_start.shape[:2], device = self.device)
            noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')

        # noise sample
        x = self.q_sample(x_start=x_start, t=t, noise=noise) # Use q_sample here for updating: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L763

        model_out = self.diff_unet(x, t, conds=cond)

        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_start
        elif self.objective == 'pred_v':
            v = self.predict_v(x_start, t, noise)
            target = v
        else:
            raise ValueError(f'unknown objective {self.objective}')

        loss = F.mse_loss(model_out, target, reduction = 'none') # (B, T, C, H, W)
        loss = reduce(loss, 'b ... -> b', 'mean')

        # Per-sample weighting by the precomputed SNR-based loss_weight[t].
        loss = loss * extract(self.loss_weight, t, loss.shape)
        return loss.mean()

    @torch.no_grad()
    def forward(self, input_x, include_mu=False, **kwargs):
        """Inference entry point: returns predicted frames (and optionally
        the backbone output when include_mu=True)."""
        pred, mu = self.predict(input_x, compute_loss=False)
        if include_mu:
            return pred, mu
        else:
            return pred
585
+
586
+ from stldm.modules import SimVPV2_Model, VAE
587
def model_setup(model_config, print_info=False, cfg_str=None):
    """Build the full GaussianDiffusion model from a config dict.

    model_config must provide 'vp_param' (backbone), 'stldm_param'
    (latent U-Net) and 'param' (diffusion settings). When cfg_str is
    given, a constant classifier-free-guidance schedule of that strength
    is installed; otherwise guidance is disabled.
    """
    if print_info:
        print('Setup the model with considering temporal attention be (BHW, T, C) ... ...')
        print('Train it from end to end')

    vpm = SimVPV2_Model(**model_config['vp_param'])
    ldm = LDM(**model_config['stldm_param'])
    model = GaussianDiffusion(vp_model=vpm, diffusion=ldm, **model_config['param'])

    if cfg_str is not None:
        scheduler = guidance_scheduler(sampling_step=model_config['param']['timesteps'], const=cfg_str)
    else:
        scheduler = None
    model.setup_guidance(scheduler)

    return model
602
+
603
def ae_setup(model_config):
    """Build only the VAE of the backbone described by model_config['vp_param']."""
    backbone = SimVPV2_Model(**model_config['vp_param'])
    return backbone.vae
608
+
609
def backbone_setup(model_config):
    """Build only the video-prediction backbone from model_config['vp_param']."""
    return SimVPV2_Model(**model_config['vp_param'])
stldm/stldm_hf.py ADDED
@@ -0,0 +1,620 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, random
2
+ from torch import nn
3
+ from einops import rearrange
4
+
5
+ from stldm.submodules import *
6
+
7
class Down_Block(nn.Module):
    """Encoder stage of the diffusion U-Net: ResNet blocks with spatial and
    temporal self-attention, then a 2x downsample (or a plain channel
    conversion on the last stage).

    The conditioning tensor is concatenated channel-wise with the input, so
    ``in_ch`` must already account for the doubled channels (the ``LDM``
    constructor passes ``2*ch_in``).
    """
    def __init__(self, in_ch, hid_ch, out_ch, time_dim, is_last, patch_size=None, num_groups=8, heads=4, dim_head=32):
        super(Down_Block, self).__init__()
        self.block1 = ResnetBlock(dim=in_ch, dim_out=hid_ch, time_emb_dim=time_dim, groups=num_groups)
        # Full (quadratic) spatial attention when no patch size is given,
        # otherwise a cheaper linear attention over patches.
        self.attn_spatial = Residual(PreNorm(hid_ch, Quadratic_SpatialAttention(dim=hid_ch, heads=heads, dim_head=dim_head))) if patch_size is None else Residual(PreNorm(hid_ch, Linear_SpatialAttention(dim=hid_ch, patch_size=patch_size, heads=heads, dim_head=dim_head)))
        self.block2 = ResnetBlock(dim=hid_ch, dim_out=hid_ch, groups=num_groups)
        # self.attn_temporal = Residual(PreNorm(hid_ch, TemporalAttention_Pos(dim=hid_ch, heads=heads, dim_head=dim_head)))
        self.attn_temporal = Residual(PreNorm(hid_ch, TemporalAttention(dim=hid_ch, heads=heads, dim_head=dim_head)))
        self.last = Downsample2D(dim_in=hid_ch, dim_out=out_ch) if not is_last else ChannelConversion(hid_ch, out_ch)

    def forward(self, x, time_emb, cond=None, relative_pos=None):
        """Process a (B, T, C, H, W) feature volume.

        Returns the downsampled output plus the spatial- and temporal-attention
        activations, which the decoder later consumes as skip connections.
        ``relative_pos`` is accepted for API compatibility but unused here
        (the positional temporal attention variant is commented out).
        """
        assert x.ndim==5
        B, T, C, H, W = x.shape

        # Fold time into the batch dim so the 2D blocks see (B*T, C, H, W).
        x = x.reshape(B*T, C, H, W)
        if cond is None:
            cond = torch.zeros_like(x) # -> Unconditioning

        time_emb = time_emb.unsqueeze(1) # From (B C) to (B 1 C)
        time_emb = time_emb.repeat(1, T, 1)
        time_emb = time_emb.reshape(B*T, -1)

        out = torch.cat((x, cond), dim=1) # BT, 2C, H, W
        out = self.block1(out, time_emb)

        spatial_attn = self.attn_spatial(out)
        out = self.block2(spatial_attn, time_emb)
        *_, c, h, w = out.shape
        out = out.reshape(B,T,c,h,w)

        # Temporal attention operates across the T axis of the 5-D view.
        # temporal_attn = self.attn_temporal(out, relative_pos)
        temporal_attn = self.attn_temporal(out)
        temporal_attn = temporal_attn.reshape(B*T,c,h,w)

        out = self.last(temporal_attn)
        *_, c, h, w = out.shape

        return out.reshape(B, T, c, h, w), spatial_attn, temporal_attn
45
+
46
class MidBlock(nn.Module):
    """Bottleneck of the diffusion U-Net: three ResNet blocks interleaved
    with one spatial and one temporal self-attention, channel count
    unchanged throughout."""
    def __init__(self, in_ch, time_dim, num_groups=8, heads=4, dim_head=32):
        super(MidBlock, self).__init__()
        self.block1 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)
        self.qattn_spatial = Residual(PreNorm(in_ch, Quadratic_SpatialAttention(dim=in_ch, heads=heads, dim_head=dim_head)))
        self.block2 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)
        # self.qattn_time = Residual(PreNorm(in_ch, TemporalAttention_Pos(dim=in_ch, heads=heads, dim_head=dim_head)))
        self.qattn_time = Residual(PreNorm(in_ch, TemporalAttention(dim=in_ch, heads=heads, dim_head=dim_head)))
        self.block3 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)

    def forward(self, x, time_emb, relative_pos=None):
        """Transform a (B, T, C, H, W) volume at the bottleneck resolution.
        ``relative_pos`` is accepted for API compatibility but unused (the
        positional temporal attention call is commented out)."""
        assert x.ndim==5
        B, T, C, H, W = x.shape
        x = x.reshape(B*T, C, H, W)

        time_emb = time_emb.unsqueeze(1) # From (B C) to (B 1 C)
        time_emb = time_emb.repeat(1, T, 1)
        time_emb = time_emb.reshape(B*T, -1)

        out = self.block1(x, time_emb)
        out = self.qattn_spatial(out)
        out = self.block2(out, time_emb) # a little bit difference here

        # Restore the 5-D view (channels unchanged) for temporal attention.
        out = out.reshape((B, T, C, H, W))
        # out = self.qattn_time(out, relative_pos).reshape(B*T, C, H, W)
        out = self.qattn_time(out).reshape(B*T, C, H, W)
        out = self.block3(out, time_emb)

        *_, c, h, w = out.shape
        return out.reshape(B, T, c, h, w)
76
+
77
class Up_Block(nn.Module):
    """Decoder stage of the diffusion U-Net: upsample (or channel
    conversion on the outermost stage), temporal attention fused with the
    encoder's temporal-attention skip, then spatial attention fused with the
    encoder's spatial-attention skip."""
    def __init__(self, in_chs, hid_ch, out_ch, is_last, time_dim, patch_size=None, num_groups=8, heads=4, dim_head=32):
        super(Up_Block, self).__init__()
        # in_chs = (channels coming up the decoder, channels of the skips).
        in_ch, skip_ch = in_chs
        self.up = Upsample2D(dim_in=in_ch, dim_out=hid_ch) if not is_last else ChannelConversion(in_ch, hid_ch)
        self.attn_spatial = Residual(PreNorm(hid_ch, Quadratic_SpatialAttention(dim=hid_ch, heads=heads, dim_head=dim_head) if patch_size is None else Linear_SpatialAttention(dim=hid_ch, patch_size=patch_size, heads=heads, dim_head=dim_head)))
        self.block1 = ResnetBlock(dim=hid_ch+skip_ch, dim_out=hid_ch, time_emb_dim=time_dim, groups=num_groups)
        # self.attn_temporal = Residual(PreNorm(hid_ch, TemporalAttention_Pos(dim=hid_ch, heads=heads, dim_head=dim_head)))
        self.attn_temporal = Residual(PreNorm(hid_ch, TemporalAttention(dim=hid_ch, heads=heads, dim_head=dim_head)))
        self.block2 = ResnetBlock(dim=hid_ch+skip_ch, dim_out=out_ch, time_emb_dim=time_dim, groups=num_groups)

    def forward(self, x, time_emb, spatialattn_skip, tempattn_skip, relative_pos=None):
        """Decode a (B, T, C, H, W) volume, consuming the two skip tensors
        produced by the matching ``Down_Block``. ``relative_pos`` is accepted
        for API compatibility but unused."""
        assert x.ndim==5
        B, T, C, H, W = x.shape
        x = x.reshape(B*T, C, H, W)

        time_emb = time_emb.unsqueeze(1) # From (B C) to (B 1 C)
        time_emb = time_emb.repeat(1, T, 1)
        time_emb = time_emb.reshape(B*T, -1)

        out = self.up(x)
        *_, c, h, w = out.shape
        out = out.reshape(-1, T, c, h, w)

        # out = self.attn_temporal(out, relative_pos).reshape(B*T, c, h, w)
        out = self.attn_temporal(out).reshape(B*T, c, h, w)

        # Fuse with the encoder's temporal-attention activation first ...
        out = torch.cat((out, tempattn_skip), dim=1)
        out = self.block1(out, time_emb)

        out = self.attn_spatial(out)

        # ... then with its spatial-attention activation.
        out = torch.cat((out, spatialattn_skip), dim=1)
        out = self.block2(out, time_emb)
        *_, c, h, w = out.shape
        return out.reshape(B, T, c, h, w)
113
+
114
class LDM(nn.Module):
    """Spatio-temporal U-Net used as the denoiser of the latent diffusion
    model: a stack of ``Down_Block``s, a ``MidBlock`` bottleneck, and
    mirrored ``Up_Block``s with attention skip connections. The conditioning
    tensor is progressively downsampled in lock-step with the encoder."""
    def __init__(self, in_ch, chs_mult:tuple, patch_size=None, num_groups=8, heads=4, dim_head=32, base_ch=64):
        super(LDM, self).__init__()
        # Time Embedding MLP
        time_dim = 4*base_ch
        fourier_dim = base_ch
        self.time_mlp = Time_MLP(dim=base_ch, time_dim=time_dim, fourier_dim=fourier_dim)

        ups, downs = [], []
        conditions = []

        # Channel plan: [in_ch, base*m1, base*m2, ...]; decoder mirrors it.
        layer_no = len(chs_mult)
        chs = [in_ch, *map(lambda m: base_ch*m, chs_mult)]
        ch_in, ch_out = chs[:-1], chs[1:]
        up_in, up_out = list(reversed(ch_out)), list(reversed(ch_in))

        patches = None if patch_size is None else [patch_size//(2**n) for n in range(layer_no)] # Patch Size should be 2^N
        for n in range(layer_no):
            # Down blocks take 2*ch_in because input and conditioning are
            # concatenated channel-wise inside Down_Block.forward.
            downs.append(
                Down_Block(in_ch=2*ch_in[n], hid_ch=ch_in[n], out_ch=ch_out[n], time_dim=time_dim, patch_size=None if patch_size is None else patches[n], is_last=(n==layer_no-1), num_groups=num_groups, heads=heads, dim_head=dim_head)
            )
            ups.append(
                Up_Block(in_chs=(up_in[n], ch_in[-n-1]), hid_ch=up_in[n], out_ch=up_out[n], time_dim=time_dim, patch_size=None if patch_size is None else patches[layer_no-n-1], is_last=(n==0), num_groups=num_groups, heads=heads, dim_head=dim_head)
            )
            # NOTE(review): `n != -1` is always true, so a conditioning
            # downsampler is created for every stage — including the last,
            # whose output is never consumed (the encoder loop ends there).
            if n != -1:
                conditions.append(
                    Downsample2D(dim_in=ch_in[n], dim_out=ch_out[n])
                )

        self.downs = nn.ModuleList(downs)
        self.ups = nn.ModuleList(ups)
        self.conditions = nn.ModuleList(conditions)
        self.mid = MidBlock(in_ch=ch_out[-1], time_dim=time_dim, num_groups=num_groups, heads=heads, dim_head=dim_head)
        # self.relative_pos = RelativePositionBias(heads=heads)

    def forward(self, x, time, conds=None):
        """Denoise latent volume *x* at diffusion step *time*, optionally
        conditioned on *conds* (same spatial layout as *x*; ``None`` means
        unconditional — the down blocks substitute zeros)."""
        t = self.time_mlp(time)

        # Skip stacks, popped in reverse order by the decoder.
        hid_spatial = []
        hid_temporal = []

        # relative_position = self.relative_pos(x.shape[1], x.device) # Calculate The Relative Position

        for n, down_block in enumerate(self.downs):
            # print(x.shape)
            # x, spatial_attn, time_attn = down_block(x, t, conds, relative_position)
            x, spatial_attn, time_attn = down_block(x, t, conds)
            hid_spatial.append(spatial_attn)
            hid_temporal.append(time_attn)
            if conds is not None:
                # Downsample the conditioning to match the next stage.
                conds = self.conditions[n](conds)

        # out = self.mid(x, t, relative_position)
        out = self.mid(x, t)

        for up_block in self.ups:
            # out = up_block(out, t, hid_spatial.pop(), hid_temporal.pop(), relative_position)
            out = up_block(out, t, hid_spatial.pop(), hid_temporal.pop())

        return out
174
+
175
+ # constants
176
+ from collections import namedtuple
177
+ from torch.cuda.amp import autocast
178
+ import torch.nn.functional as F
179
+ from einops import reduce
180
+ from tqdm.auto import tqdm
181
+
182
+ ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
183
+
184
def identity(t, *args, **kwargs):
    """No-op transform: return *t* unchanged, ignoring any extra arguments
    (used as the fallback when clipping is disabled)."""
    return t
186
+
187
def extract(a, t, x_shape):
    """Gather per-sample coefficients from schedule vector *a* at timestep
    indices *t*, reshaped to (B, 1, 1, ...) so they broadcast against a
    tensor of shape *x_shape*."""
    batch = t.shape[0]
    coeffs = a.gather(-1, t)
    trailing = (1,) * (len(x_shape) - 1)
    return coeffs.reshape(batch, *trailing)
191
+
192
def default(val, d):
    """Return *val* when it is not ``None``; otherwise fall back to *d*,
    calling it first if it is callable (lazy default factory)."""
    if not exists(val):
        return d() if callable(d) else d
    return val

def exists(x):
    """True when *x* is not ``None`` (falsy values like 0/"" still exist)."""
    return x is not None
199
+
200
def guidance_scheduler(sampling_step: int, const: float):
    """Constant classifier-free-guidance schedule: a float32 vector of
    length ``sampling_step`` filled with ``const``."""
    return torch.full((sampling_step,), const, dtype=torch.float32)
202
+
203
+ from huggingface_hub import PyTorchModelHubMixin
204
+
205
class GaussianDiffusion(
    nn.Module,
    PyTorchModelHubMixin,
    # optionally, you can add metadata which gets pushed to the model card
    repo_url="https://github.com/sqfoo/stldm_official",
    pipeline_tag="Precipitation_Nowcasting",
    license="mit"):
    """Latent spatio-temporal diffusion model for precipitation nowcasting.

    Combines a deterministic SimVPV2 backbone (which provides the latent
    conditioning and owns the VAE used to move between pixel and latent
    space) with a diffusion U-Net (``LDM``) that denoises latent future
    frames. The DDPM/DDIM machinery follows lucidrains'
    denoising-diffusion-pytorch. Inherits ``PyTorchModelHubMixin`` so
    checkpoints can be pushed to / loaded from the Hugging Face Hub.
    """
    def __init__(
        self,
        vp_param: dict,
        stldm_param: dict,
        timesteps = 1000,
        sampling_timesteps = None,
        objective = 'pred_v',
        beta_schedule = 'sigmoid',
        schedule_fn_kwargs = dict(),  # NOTE(review): mutable default — safe only because it is never mutated here
        ddim_sampling_eta = 0.,
        offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise
        min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
        min_snr_gamma = 5
    ):
        super(GaussianDiffusion, self).__init__()

        # NOTE(review): SimVPV2_Model is imported near the bottom of this
        # module (or possibly via the `stldm.submodules` star import); the
        # global lookup happens at call time, so construction still resolves.
        self.backbone = SimVPV2_Model(**vp_param)
        self.diff_unet = LDM(**stldm_param)

        self.objective = objective
        assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'

        # Noise-variance (beta) schedule; the schedule functions are expected
        # to come from the `stldm.submodules` star import — TODO confirm.
        if beta_schedule == 'linear':
            beta_schedule_fn = linear_beta_schedule
        elif beta_schedule == 'cosine':
            beta_schedule_fn = cosine_beta_schedule
        elif beta_schedule == 'sigmoid':
            beta_schedule_fn = sigmoid_beta_schedule
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')

        betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)

        # sampling related parameters

        self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training

        assert self.sampling_timesteps <= timesteps
        # Fewer sampling steps than training steps -> use DDIM sampling.
        self.is_ddim_sampling = self.sampling_timesteps < timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

        # helper function to register buffer from float64 to float32

        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others

        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)

        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)

        register_buffer('posterior_variance', posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain

        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        # offset noise strength - in blogpost, they claimed 0.1 was ideal

        self.offset_noise_strength = offset_noise_strength

        # derive loss weight
        # snr - signal noise ratio

        snr = alphas_cumprod / (1 - alphas_cumprod)

        # https://arxiv.org/abs/2303.09556

        maybe_clipped_snr = snr.clone()
        if min_snr_loss_weight:
            maybe_clipped_snr.clamp_(max = min_snr_gamma)

        # Per-timestep loss weights derived from SNR, one form per objective.
        if objective == 'pred_noise':
            register_buffer('loss_weight', maybe_clipped_snr / snr)
        elif objective == 'pred_x0':
            register_buffer('loss_weight', maybe_clipped_snr)
        elif objective == 'pred_v':
            register_buffer('loss_weight', maybe_clipped_snr / (snr + 1))

    @property
    def device(self):
        """Device of the registered diffusion buffers (and thus the model)."""
        return self.betas.device

    # CFG schdeuler => by taking pre-setting scheduler
    def setup_guidance(self, scheduler):
        """Attach a per-timestep classifier-free-guidance strength tensor,
        or disable guidance entirely by passing ``None``."""
        if exists(scheduler):
            self.CFG_sch = scheduler.to(self.device)
        else:
            self.CFG_sch = scheduler

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover x_0 from a noisy latent: x_0 = sqrt(1/acp_t)*x_t - sqrt(1/acp_t - 1)*eps."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        """Inverse of ``predict_start_from_noise``: recover eps from x_t and x_0."""
        return (
            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    def predict_v(self, x_start, t, noise):
        """v-parameterization target: v = sqrt(acp_t)*eps - sqrt(1-acp_t)*x_0."""
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
        )

    def predict_start_from_v(self, x_t, t, v):
        """Recover x_0 from a v-prediction: x_0 = sqrt(acp_t)*x_t - sqrt(1-acp_t)*v."""
        return (
            extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def q_posterior(self, x_start, x_t, t):
        """Mean/variance of the forward-process posterior q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def model_predictions(self, x, t, cond, clip_x_start = False, rederive_pred_noise = False):
        """Run the U-Net once (twice under CFG) and return both the noise and
        x_0 views of its output as a ``ModelPrediction``.

        The CFG combination below is algebraically (1+w)*cond - w*uncond with
        w = CFG_sch[t]. Assumes all entries of *t* share one timestep (only
        t[0] indexes the schedule) — TODO confirm callers guarantee this.
        """
        # print(t.device)
        if exists(self.CFG_sch):
            uncond = self.diff_unet(x, t, conds=None) #conds=torch.zeros_like(cond))
            model_output = self.diff_unet(x, t, conds=cond)
            time = int(t[0])
            model_output = model_output - self.CFG_sch[time] * (uncond - model_output)
        else:
            model_output = self.diff_unet(x, t, conds=cond)
        # NOTE(review): `partial` is not imported in this file — presumably
        # re-exported by the `stldm.submodules` star import; confirm.
        maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity

        if self.objective == 'pred_noise':
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, pred_noise)
            x_start = maybe_clip(x_start)

            if clip_x_start and rederive_pred_noise:
                pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_x0':
            x_start = model_output
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_v':
            v = model_output
            x_start = self.predict_start_from_v(x, t, v)
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        return ModelPrediction(pred_noise, x_start)

    def p_mean_variance(self, x, t, cond=None, clip_denoised = True):
        """Posterior mean/variance for one ancestral (DDPM) sampling step."""
        preds = self.model_predictions(x, t, cond=cond, clip_x_start=False,)
        x_start = preds.pred_x_start

        if clip_denoised:
            x_start.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.no_grad()
    def p_sample(self, x, t: int, cond=None):
        """One ancestral DDPM step from timestep *t* to *t-1*."""
        b, *_, device = *x.shape, self.device
        batched_times = torch.full((b,), t, device = device, dtype = torch.long)
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, cond=cond, clip_denoised = False)
        noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred_img, x_start

    @torch.no_grad()
    def p_sample_loop(self, shape, cond=None, return_all_timesteps = False):
        """Full ancestral DDPM sampling from pure noise down to t=0."""
        batch, device = shape[0], self.device

        frames_pred = torch.randn(shape, device = device)
        imgs = [frames_pred]

        for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps, disable=True):
            frames_pred, _ = self.p_sample(frames_pred, t, cond=cond)
            imgs.append(frames_pred)

        ret = frames_pred if not return_all_timesteps else torch.stack(imgs, dim = 1)
        return ret

    @torch.no_grad()
    def ddim_sample(self, shape, cond=None, return_all_timesteps = False):
        """Deterministic-ish DDIM sampling over ``self.sampling_timesteps``
        steps (stochasticity controlled by ``ddim_sampling_eta``)."""
        batch, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
        device = self.device
        times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        frames_pred = torch.randn(shape, device = device)
        imgs = [frames_pred]

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', disable=True):
            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
            pred_noise, x_start, *_ = self.model_predictions(
                frames_pred,
                time_cond,
                cond = cond, #cond.copy(),
                clip_x_start = False,
                rederive_pred_noise = True
            )

            # Final pair reaches the virtual timestep -1: emit x_0 directly.
            if time_next < 0:
                frames_pred = x_start
                imgs.append(frames_pred)
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()

            noise = torch.randn_like(frames_pred)

            frames_pred = x_start * alpha_next.sqrt() + \
                          c * pred_noise + \
                          sigma * noise

            imgs.append(frames_pred)

        ret = frames_pred if not return_all_timesteps else torch.stack(imgs, dim = 1)
        return ret

    @torch.no_grad()
    def sample(self, frames_in, return_all_timesteps = False):
        """Sample future frames given past frames (B, T_in, C, H, W): run the
        backbone for conditioning, denoise in latent space, then decode with
        the VAE. Returns (diffusion prediction, backbone output).

        NOTE(review): with ``return_all_timesteps=True`` the latent gains an
        extra axis that the 'b t c h w' rearrange below cannot absorb —
        confirm that flag is only used for latent inspection.
        """
        assert frames_in.ndim == 5
        B, T_in, C, H, W = frames_in.shape
        device = self.device

        # `*_` tolerates backbones returning extra values beyond (output, conds).
        backbone_output, conds, *_ = self.backbone(frames_in)
        sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample

        *_, c, h, w = conds.shape
        tgt_shape = conds.reshape(B, -1, c, h, w).shape
        ldm_pred = sample_fn(
            tgt_shape,
            cond=conds,
            return_all_timesteps = return_all_timesteps
        )

        ldm_pred = rearrange(ldm_pred, 'b t c h w -> (b t) c h w')
        frames_pred = self.backbone.vae.decode(ldm_pred)
        frames_pred = rearrange(frames_pred, '(b t) c h w -> b t c h w', b=B)
        return frames_pred, backbone_output

    def predict(self, frames_in, compute_loss=False, **kwargs):
        """Thin wrapper around :meth:`sample`; ``compute_loss`` and extra
        kwargs are accepted for interface compatibility but ignored."""
        pred, mu = self.sample(frames_in=frames_in)
        return pred, mu

    def compute_loss(self, frames_in, frames_gt, validate=False):
        """Training objective: returns (VAE reconstruction loss, backbone
        loss, diffusion loss, prior loss at t=T) as separate terms.

        NOTE(review): the local `compute_loss` flag below is never used, and
        the bare triple-quoted strings are no-op statements serving as
        section headers.
        """
        compute_loss = True and (not validate)
        B, T_in, C, H, W = frames_in.shape
        T_out = frames_gt.shape[1]
        device = frames_in.device

        """
        Diffusion Loss
        """
        backbone_output, conds = self.backbone(frames_in)
        # Encode ground-truth frames into the latent space the U-Net denoises.
        hid_gt, _ = self.backbone.vae.encode(
            rearrange(frames_gt, 'b t c h w -> (b t) c h w')
        )
        hid_gt = rearrange(hid_gt, '(b t) c h w -> b t c h w', b=B)
        t = torch.randint(0, self.num_timesteps, (B,), device=self.device).long()
        # 15% conditioning dropout for classifier-free guidance training.
        if random.random() > 0.85: # Unconditional
            conds = None
        diff_loss = self.p_losses(hid_gt.detach(), t, cond=conds)

        """
        Backbone Loss
        """
        mu_loss = self.backbone._losses_(frames_in, frames_gt)

        """
        VAE Loss
        """
        # Autoencoding loss over the concatenated input+target sequence.
        ae_loss, kl_loss = self.backbone.vae._losses_(
            rearrange(torch.cat((frames_in, frames_gt), dim=1), 'b t c h w -> (b t) c h w'),
            rearrange(torch.cat((frames_in, frames_gt), dim=1), 'b t c h w -> (b t) c h w')
        )
        kl_weight = 1E-6
        recon_loss = ae_loss + kl_weight*kl_loss

        """
        Prior Loss at t=T [Noisy]
        """
        # Penalize divergence of the fully-noised latent from N(0, I).
        hid_gt, _ = self.backbone.vae.encode(
            rearrange(frames_gt, 'b t c h w -> (b t) c h w')
        )
        hid_gt = rearrange(hid_gt, '(b t) c h w -> b t c h w', b=B)
        T = torch.ones((B,), device=self.device).long() * (self.num_timesteps - 1)
        mu_noisy = extract(self.sqrt_alphas_cumprod, T, hid_gt.shape) * hid_gt
        sigma_noisy = extract(self.sqrt_one_minus_alphas_cumprod, T, hid_gt.shape)
        log_var_noisy = 2*torch.log(sigma_noisy)
        prior_loss = self.kl_from_standard_normal(mu_noisy, log_var_noisy)

        return recon_loss, mu_loss, diff_loss, prior_loss


    def kl_from_standard_normal(self, mean, log_var):
        """Mean element-wise KL divergence of N(mean, exp(log_var)) from N(0, 1)."""
        kl = 0.5 * (log_var.exp() + mean.square() - 1.0 - log_var)
        return kl.mean()

    @autocast(enabled = False)
    def q_sample(self, x_start, t, noise = None):
        """Forward diffusion: sample x_t ~ q(x_t | x_0) in full precision."""
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, t, noise=None, offset_noise_strength=None, cond=None):
        """SNR-weighted MSE between the U-Net output and the target implied
        by ``self.objective`` (noise / x_0 / v), averaged over the batch."""
        b, T, c, h, w = x_start.shape

        noise = default(noise, lambda: torch.randn_like(x_start))

        # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
        offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)

        # NOTE(review): this branch was written for 4-D images; with 5-D video
        # latents the (b, T, 1, 1) offset does not broadcast against
        # (b, T, c, h, w) — confirm before enabling a nonzero strength.
        if offset_noise_strength > 0.:
            offset_noise = torch.randn(x_start.shape[:2], device = self.device)
            noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')

        # noise sample
        x = self.q_sample(x_start=x_start, t=t, noise=noise) # Use q_sample here for updating: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L763

        model_out = self.diff_unet(x, t, conds=cond)

        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_start
        elif self.objective == 'pred_v':
            v = self.predict_v(x_start, t, noise)
            target = v
        else:
            raise ValueError(f'unknown objective {self.objective}')

        loss = F.mse_loss(model_out, target, reduction = 'none') # (B, T, C, H, W)
        loss = reduce(loss, 'b ... -> b', 'mean')

        loss = loss * extract(self.loss_weight, t, loss.shape)
        return loss.mean()

    @torch.no_grad()
    def forward(self, input_x, include_mu=False, **kwargs):
        """Inference entry point: sample a prediction; optionally also return
        the backbone's deterministic (mean) output."""
        pred, mu = self.predict(input_x, compute_loss=False)
        if include_mu:
            return pred, mu
        else:
            return pred
593
+
594
from stldm.modules import SimVPV2_Model, VAE
def model_setup(model_config, print_info=False, cfg_str=None):
    """Build a ``GaussianDiffusion`` model from *model_config*, optionally
    attaching a constant classifier-free-guidance schedule.

    Bug fix: the previous version instantiated the backbone and U-Net itself
    and passed them as ``vp_model=`` / ``diffusion=``, but this module's
    ``GaussianDiffusion.__init__`` takes ``vp_param`` / ``stldm_param``
    config dicts and constructs both modules internally — the old call
    raised ``TypeError`` on unexpected keyword arguments.

    Args:
        model_config: dict with keys ``'vp_param'`` (backbone config),
            ``'stldm_param'`` (U-Net config) and ``'param'`` (diffusion
            hyper-parameters, including ``'timesteps'``).
        print_info: when true, print a short description of the setup.
        cfg_str: optional constant CFG strength; ``None`` disables guidance.

    Returns:
        The configured ``GaussianDiffusion`` instance.
    """
    if print_info:
        print('Setup the model with considering temporal attention be (BHW, T, C) ... ...')
        print('Train it from end to end')

    model = GaussianDiffusion(
        vp_param=model_config['vp_param'],
        stldm_param=model_config['stldm_param'],
        **model_config['param']
    )

    # Constant CFG schedule over all timesteps when a strength is provided.
    scheduler = guidance_scheduler(sampling_step=model_config['param']['timesteps'], const=cfg_str) if cfg_str is not None else None
    model.setup_guidance(scheduler)

    return model
610
+
611
def ae_setup(model_config):
    """Build only the VAE component of the SimVP backbone described by
    *model_config* (for training/evaluating the autoencoder on its own)."""
    backbone = SimVPV2_Model(**model_config['vp_param'])
    return backbone.vae
616
+
617
def backbone_setup(model_config):
    """Instantiate the SimVPV2 backbone described by *model_config* alone."""
    return SimVPV2_Model(**model_config['vp_param'])
stldm/stldm_spatial.py ADDED
@@ -0,0 +1,593 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, random
2
+ from torch import nn
3
+ from einops import rearrange
4
+
5
+ from stldm.submodules import *
6
+
7
class Down_Block(nn.Module):
    """Spatial-only encoder stage of the diffusion U-Net variant: two ResNet
    blocks with one spatial attention (the second attention slot is an
    identity), followed by a 2x downsample (or channel conversion on the
    last stage). Conditioning is concatenated channel-wise, so ``in_ch``
    must already be doubled by the caller."""
    def __init__(self, in_ch, hid_ch, out_ch, time_dim, is_last, patch_size=None, num_groups=8, heads=4, dim_head=32):
        super(Down_Block, self).__init__()
        self.block1 = ResnetBlock(dim=in_ch, dim_out=hid_ch, time_emb_dim=time_dim, groups=num_groups)
        # Full (quadratic) attention without a patch size, linear otherwise.
        self.attn1 = Residual(PreNorm(hid_ch, Quadratic_SpatialAttention(dim=hid_ch, heads=heads, dim_head=dim_head))) if patch_size is None else Residual(PreNorm(hid_ch, Linear_SpatialAttention(dim=hid_ch, patch_size=patch_size, heads=heads, dim_head=dim_head)))
        self.block2 = ResnetBlock(dim=hid_ch, dim_out=hid_ch, groups=num_groups)
        # Second attention disabled (pass-through) in this ablation variant.
        self.attn2 = nn.Identity()
        # self.attn2 = Residual(PreNorm(hid_ch, Quadratic_SpatialAttention(dim=hid_ch, heads=heads, dim_head=dim_head))) if patch_size is None else Residual(PreNorm(hid_ch, Linear_SpatialAttention(dim=hid_ch, patch_size=patch_size, heads=heads, dim_head=dim_head)))
        self.last = Downsample2D(dim_in=hid_ch, dim_out=out_ch) if not is_last else ChannelConversion(hid_ch, out_ch)

    def forward(self, x, time_emb, cond=None):
        """Process a (B, T, C, H, W) volume; returns the downsampled output
        and two intermediate activations used as decoder skip connections."""
        assert x.ndim==5
        B, T, C, H, W = x.shape

        # Fold time into the batch dim so the 2D blocks see (B*T, C, H, W).
        x = x.reshape(B*T, C, H, W)
        if cond is None:
            cond = torch.zeros_like(x) # -> Unconditioning

        time_emb = time_emb.unsqueeze(1) # From (B C) to (B 1 C)
        time_emb = time_emb.repeat(1, T, 1)
        time_emb = time_emb.reshape(B*T, -1)

        out = torch.cat((x, cond), dim=1) # BT, 2C, H, W
        out = self.block1(out, time_emb)

        skip1 = self.attn1(out)
        out = self.block2(skip1, time_emb)

        skip2 = self.attn2(out)

        out = self.last(skip2)
        *_, c, h, w = out.shape

        return out.reshape(B, T, c, h, w), skip1, skip2
41
+
42
class MidBlock(nn.Module):
    """Bottleneck of the spatial-only U-Net variant: three ResNet blocks
    with one spatial attention (the second attention slot is an identity).
    Channel count is unchanged throughout."""
    def __init__(self, in_ch, time_dim, num_groups=8, heads=4, dim_head=32):
        super(MidBlock, self).__init__()
        self.block1 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)
        self.attn1 = Residual(PreNorm(in_ch, Quadratic_SpatialAttention(dim=in_ch, heads=heads, dim_head=dim_head)))
        self.block2 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)
        # Second attention disabled (pass-through) in this ablation variant.
        self.attn2 = nn.Identity()
        # self.attn2 = Residual(PreNorm(in_ch, Quadratic_SpatialAttention(dim=in_ch, heads=heads, dim_head=dim_head)))
        self.block3 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)

    def forward(self, x, time_emb, relative_pos=None):
        """Transform a (B, T, C, H, W) volume at the bottleneck resolution.
        ``relative_pos`` is accepted for API compatibility but unused."""
        assert x.ndim==5
        B, T, C, H, W = x.shape
        x = x.reshape(B*T, C, H, W)

        time_emb = time_emb.unsqueeze(1) # From (B C) to (B 1 C)
        time_emb = time_emb.repeat(1, T, 1)
        time_emb = time_emb.reshape(B*T, -1)

        out = self.block1(x, time_emb)
        out = self.attn1(out)
        out = self.block2(out, time_emb) # a little bit difference here
        out = self.attn2(out)
        out = self.block3(out, time_emb)

        *_, c, h, w = out.shape
        return out.reshape(B, T, c, h, w)
69
+
70
class Up_Block(nn.Module):
    """Decoder stage of the spatial-only U-Net variant: upsample (or channel
    conversion on the outermost stage), spatial attention, and two ResNet
    blocks fed by the encoder's skip activations."""
    def __init__(self, in_chs, hid_ch, out_ch, is_last, time_dim, patch_size=None, num_groups=8, heads=4, dim_head=32):
        super(Up_Block, self).__init__()
        # in_chs = (channels coming up the decoder, channels of the skips).
        in_ch, skip_ch = in_chs
        self.up = Upsample2D(dim_in=in_ch, dim_out=hid_ch) if not is_last else ChannelConversion(in_ch, hid_ch)
        self.attn1 = Residual(PreNorm(hid_ch, Quadratic_SpatialAttention(dim=hid_ch, heads=heads, dim_head=dim_head) if patch_size is None else Linear_SpatialAttention(dim=hid_ch, patch_size=patch_size, heads=heads, dim_head=dim_head)))
        self.block1 = ResnetBlock(dim=hid_ch+skip_ch, dim_out=hid_ch, time_emb_dim=time_dim, groups=num_groups)
        # Second attention disabled (pass-through) in this ablation variant.
        self.attn2 = nn.Identity()
        # self.attn2 = Residual(PreNorm(hid_ch, Quadratic_SpatialAttention(dim=hid_ch, heads=heads, dim_head=dim_head) if patch_size is None else Linear_SpatialAttention(dim=hid_ch, patch_size=patch_size, heads=heads, dim_head=dim_head)))
        self.block2 = ResnetBlock(dim=hid_ch+skip_ch, dim_out=out_ch, time_emb_dim=time_dim, groups=num_groups)

    def forward(self, x, time_emb, skip1, skip2):
        """Decode a (B, T, C, H, W) volume using the matching encoder skips.

        NOTE(review): ``skip2`` is concatenated before BOTH ResNet blocks and
        ``skip1`` is never used. The stldm_hf.py Up_Block pairs block1/block2
        with different skips — confirm whether the second cat below should
        use ``skip1`` instead. (Both skips share the same channel count, so
        either choice is shape-compatible.)
        """
        assert x.ndim==5
        B, T, C, H, W = x.shape
        x = x.reshape(B*T, C, H, W)

        time_emb = time_emb.unsqueeze(1) # From (B C) to (B 1 C)
        time_emb = time_emb.repeat(1, T, 1)
        time_emb = time_emb.reshape(B*T, -1)

        out = self.up(x)
        *_, c, h, w = out.shape
        out = self.attn1(out)

        out = torch.cat((out, skip2), dim=1)
        out = self.block1(out, time_emb)

        out = self.attn2(out)

        out = torch.cat((out, skip2), dim=1)
        out = self.block2(out, time_emb)
        *_, c, h, w = out.shape
        return out.reshape(B, T, c, h, w)
103
+
104
class LDM(nn.Module):
    """Latent-diffusion UNet operating on (B, T, C, H, W) latents.

    Mirror-symmetric encoder (`Down_Block`s) / decoder (`Up_Block`s) with a
    `MidBlock` bottleneck. An optional conditioning tensor `conds` is injected
    into every encoder stage and downsampled in lockstep via `self.conditions`.
    """
    def __init__(self, in_ch, chs_mult:tuple, patch_size=None, num_groups=8, heads=4, dim_head=32, base_ch=64):
        super(LDM, self).__init__()
        # Time Embedding MLP
        time_dim = 4*base_ch
        fourier_dim = base_ch
        self.time_mlp = Time_MLP(dim=base_ch, time_dim=time_dim, fourier_dim=fourier_dim)

        ups, downs = [], []
        conditions = []

        layer_no = len(chs_mult)
        # Channel plan: [in_ch, base*m1, base*m2, ...]; decoder mirrors the encoder.
        chs = [in_ch, *map(lambda m: base_ch*m, chs_mult)]
        ch_in, ch_out = chs[:-1], chs[1:]
        up_in, up_out = list(reversed(ch_out)), list(reversed(ch_in))

        patches = None if patch_size is None else [patch_size//(2**n) for n in range(layer_no)] # Patch Size should be 2^N
        for n in range(layer_no):
            # Encoder input is 2*ch_in[n]: latent concatenated with the conditioning stream.
            downs.append(
                Down_Block(in_ch=2*ch_in[n], hid_ch=ch_in[n], out_ch=ch_out[n], time_dim=time_dim, patch_size=None if patch_size is None else patches[n], is_last=(n==layer_no-1), num_groups=num_groups, heads=heads, dim_head=dim_head)
            )
            ups.append(
                Up_Block(in_chs=(up_in[n], ch_in[-n-1]), hid_ch=up_in[n], out_ch=up_out[n], time_dim=time_dim, patch_size=None if patch_size is None else patches[layer_no-n-1], is_last=(n==0), num_groups=num_groups, heads=heads, dim_head=dim_head)
            )
            # NOTE(review): `n != -1` is always true, so a conditioning downsampler is
            # built for every stage including the last (whose output is never used in
            # forward). Probably `n != layer_no - 1` was intended — but changing it
            # would alter the module registry / state-dict keys, breaking checkpoint
            # loading, so it is left as-is.
            if n != -1:
                conditions.append(
                    Downsample2D(dim_in=ch_in[n], dim_out=ch_out[n])
                )

        self.downs = nn.ModuleList(downs)
        self.ups = nn.ModuleList(ups)
        self.conditions = nn.ModuleList(conditions)
        self.mid = MidBlock(in_ch=ch_out[-1], time_dim=time_dim, num_groups=num_groups, heads=heads, dim_head=dim_head)

    def forward(self, x, time, conds=None):
        # x: noisy latents (B, T, C, H, W); time: diffusion timestep(s) (B,);
        # conds: optional conditioning, consumed by each Down_Block then downsampled.
        t = self.time_mlp(time)
        hids1, hids2 = [], []

        for n, down_block in enumerate(self.downs):
            x, skip1, skip2 = down_block(x, t, conds)
            hids1.append(skip1)
            hids2.append(skip2)
            if conds is not None:
                conds = self.conditions[n](conds)

        out = self.mid(x, t)

        # Decoder consumes the skips deepest-first (stack pop order).
        for up_block in self.ups:
            out = up_block(out, t, hids1.pop(), hids2.pop())
        return out
154
+
155
+ # constants
156
+ from collections import namedtuple
157
+ from torch.cuda.amp import autocast
158
+ import torch.nn.functional as F
159
+ from einops import reduce
160
+ from tqdm.auto import tqdm
161
+
162
# Pair returned by GaussianDiffusion.model_predictions: predicted noise and x_0.
ModelPrediction = namedtuple('ModelPrediction', ('pred_noise', 'pred_x_start'))
163
+
164
def identity(t, *args, **kwargs):
    """No-op transform: return *t* unchanged, ignoring any extra arguments."""
    return t
166
+
167
def extract(a, t, x_shape):
    """Gather per-batch schedule values a[t] and reshape for broadcasting.

    a: 1-D schedule tensor; t: (b,) long indices; returns shape (b, 1, ..., 1)
    with as many trailing singleton dims as x_shape has beyond the batch dim.
    """
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    return gathered.view(batch, *((1,) * (len(x_shape) - 1)))
171
+
172
def default(val, d):
    """Return *val* unless it is None; otherwise return *d*, calling it if callable."""
    if val is not None:
        return val
    return d() if callable(d) else d
176
+
177
def exists(x):
    """True iff *x* is not None (falsy values like 0 and '' still count as existing)."""
    return x is not None
179
+
180
def guidance_scheduler(sampling_step: int, const: float):
    """Constant classifier-free-guidance weight schedule of length *sampling_step*."""
    schedule = torch.ones(sampling_step)
    return schedule * const
182
+
183
class GaussianDiffusion(nn.Module):
    """DDPM/DDIM latent-diffusion wrapper around a video-prediction backbone.

    `vp_model` supplies conditioning latents (and a VAE) from input frames;
    `diffusion` is the denoising UNet operating on (B, T, C, H, W) latents.
    Precomputes all beta-schedule-derived quantities as float32 buffers.
    Optional constant classifier-free guidance is installed via `setup_guidance`.
    """
    def __init__(
        self,
        vp_model,
        diffusion,
        timesteps = 1000,
        sampling_timesteps = None,
        objective = 'pred_v',
        beta_schedule = 'sigmoid',
        schedule_fn_kwargs = dict(),
        ddim_sampling_eta = 0.,
        offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise
        min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
        min_snr_gamma = 5
    ):
        super(GaussianDiffusion, self).__init__()

        self.backbone = vp_model
        self.diff_unet = diffusion

        self.objective = objective
        assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'

        if beta_schedule == 'linear':
            beta_schedule_fn = linear_beta_schedule
        elif beta_schedule == 'cosine':
            beta_schedule_fn = cosine_beta_schedule
        elif beta_schedule == 'sigmoid':
            beta_schedule_fn = sigmoid_beta_schedule
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')

        betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)

        # sampling related parameters

        self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training

        assert self.sampling_timesteps <= timesteps
        self.is_ddim_sampling = self.sampling_timesteps < timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

        # helper function to register buffer from float64 to float32

        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others

        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)

        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)

        register_buffer('posterior_variance', posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain

        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        # offset noise strength - in blogpost, they claimed 0.1 was ideal

        self.offset_noise_strength = offset_noise_strength

        # derive loss weight
        # snr - signal noise ratio

        snr = alphas_cumprod / (1 - alphas_cumprod)

        # https://arxiv.org/abs/2303.09556

        maybe_clipped_snr = snr.clone()
        if min_snr_loss_weight:
            maybe_clipped_snr.clamp_(max = min_snr_gamma)

        if objective == 'pred_noise':
            register_buffer('loss_weight', maybe_clipped_snr / snr)
        elif objective == 'pred_x0':
            register_buffer('loss_weight', maybe_clipped_snr)
        elif objective == 'pred_v':
            register_buffer('loss_weight', maybe_clipped_snr / (snr + 1))

    @property
    def device(self):
        # Device of the registered buffers (and thus the whole module).
        return self.betas.device

    # CFG schdeuler => by taking pre-setting scheduler
    def setup_guidance(self, scheduler):
        """Install (or clear, with None) a per-step classifier-free-guidance schedule.

        NOTE(review): `self.CFG_sch` only exists after this is called;
        `model_predictions` reads it unconditionally, so this must be invoked
        (possibly with None) before any sampling.
        """
        if exists(scheduler):
            self.CFG_sch = scheduler.to(self.device)
        else:
            self.CFG_sch = scheduler

    def predict_start_from_noise(self, x_t, t, noise):
        # x_0 = sqrt(1/acum) * x_t - sqrt(1/acum - 1) * eps
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        # eps = (sqrt(1/acum) * x_t - x_0) / sqrt(1/acum - 1)
        return (
            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    def predict_v(self, x_start, t, noise):
        # v-parameterization (progressive distillation, appendix D).
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
        )

    def predict_start_from_v(self, x_t, t, v):
        return (
            extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def q_posterior(self, x_start, x_t, t):
        """Mean/variance of the Gaussian posterior q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def model_predictions(self, x, t, cond, clip_x_start = False, rederive_pred_noise = False):
        """Run the denoiser and convert its output to (pred_noise, pred_x_start).

        With a CFG schedule installed, combines conditional and unconditional
        outputs as cond_out - w * (uncond_out - cond_out); the weight is looked
        up with the (assumed batch-uniform) timestep t[0].
        """
        # print(t.device)
        if exists(self.CFG_sch):
            uncond = self.diff_unet(x, t, conds=None) #conds=torch.zeros_like(cond))
            model_output = self.diff_unet(x, t, conds=cond)
            time = int(t[0])
            model_output = model_output - self.CFG_sch[time] * (uncond - model_output)
        else:
            model_output = self.diff_unet(x, t, conds=cond)
        maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity

        if self.objective == 'pred_noise':
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, pred_noise)
            x_start = maybe_clip(x_start)

            if clip_x_start and rederive_pred_noise:
                pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_x0':
            x_start = model_output
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_v':
            v = model_output
            x_start = self.predict_start_from_v(x, t, v)
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        return ModelPrediction(pred_noise, x_start)

    def p_mean_variance(self, x, t, cond=None, clip_denoised = True):
        """Posterior mean/variance for the reverse step, from the model's x_0 estimate."""
        preds = self.model_predictions(x, t, cond=cond, clip_x_start=False,)
        x_start = preds.pred_x_start

        if clip_denoised:
            x_start.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.no_grad()
    def p_sample(self, x, t: int, cond=None):
        """One ancestral (DDPM) reverse step from timestep t."""
        b, *_, device = *x.shape, self.device
        batched_times = torch.full((b,), t, device = device, dtype = torch.long)
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, cond=cond, clip_denoised = False)
        noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred_img, x_start

    @torch.no_grad()
    def p_sample_loop(self, shape, cond=None, return_all_timesteps = False):
        """Full DDPM sampling chain from pure noise down to t=0."""
        batch, device = shape[0], self.device

        frames_pred = torch.randn(shape, device = device)
        imgs = [frames_pred]

        for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps, disable=True):
            frames_pred, _ = self.p_sample(frames_pred, t, cond=cond)
            imgs.append(frames_pred)

        ret = frames_pred if not return_all_timesteps else torch.stack(imgs, dim = 1)
        return ret

    @torch.no_grad()
    def ddim_sample(self, shape, cond=None, return_all_timesteps = False):
        """Accelerated DDIM sampling over `sampling_timesteps` steps (eta controls stochasticity)."""
        batch, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
        device = self.device
        times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        frames_pred = torch.randn(shape, device = device)
        imgs = [frames_pred]

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', disable=True):
            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
            pred_noise, x_start, *_ = self.model_predictions(
                frames_pred,
                time_cond,
                cond = cond, #cond.copy(),
                clip_x_start = False,
                rederive_pred_noise = True
            )

            if time_next < 0:
                frames_pred = x_start
                imgs.append(frames_pred)
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()

            noise = torch.randn_like(frames_pred)

            frames_pred = x_start * alpha_next.sqrt() + \
                          c * pred_noise + \
                          sigma * noise

            imgs.append(frames_pred)

        ret = frames_pred if not return_all_timesteps else torch.stack(imgs, dim = 1)
        return ret

    @torch.no_grad()
    def sample(self, frames_in, return_all_timesteps = False):
        """Generate future frames: sample latents conditioned on the backbone, then VAE-decode."""
        assert frames_in.ndim == 5
        B, T_in, C, H, W = frames_in.shape
        device = self.device

        backbone_output, conds, *_ = self.backbone(frames_in) # updated for Updated loss function on 03/07
        sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample

        *_, c, h, w = conds.shape
        tgt_shape = conds.reshape(B, -1, c, h, w).shape
        ldm_pred = sample_fn(
            tgt_shape,
            cond=conds,
            return_all_timesteps = return_all_timesteps
        )

        ldm_pred = rearrange(ldm_pred, 'b t c h w -> (b t) c h w')
        frames_pred = self.backbone.vae.decode(ldm_pred)
        frames_pred = rearrange(frames_pred, '(b t) c h w -> b t c h w', b=B)
        return frames_pred, backbone_output

    def predict(self, frames_in, compute_loss=False, **kwargs):
        # NOTE(review): `compute_loss` and **kwargs are accepted but ignored here.
        pred, mu = self.sample(frames_in=frames_in)
        return pred, mu

    def compute_loss(self, frames_in, frames_gt, validate=False):
        """Compute the training losses: VAE reconstruction, backbone, diffusion, and prior KL."""
        # NOTE(review): this local is never read afterwards (and shadows the method name).
        compute_loss = True and (not validate)
        B, T_in, C, H, W = frames_in.shape
        T_out = frames_gt.shape[1]
        device = frames_in.device

        """
        Diffusion Loss
        """
        backbone_output, conds = self.backbone(frames_in)
        hid_gt, _ = self.backbone.vae.encode(
            rearrange(frames_gt, 'b t c h w -> (b t) c h w')
        )
        hid_gt = rearrange(hid_gt, '(b t) c h w -> b t c h w', b=B)
        t = torch.randint(0, self.num_timesteps, (B,), device=self.device).long()
        # 15% of batches are trained unconditionally (for classifier-free guidance).
        if random.random() > 0.85: # Unconditional
            conds = None
        diff_loss = self.p_losses(hid_gt.detach(), t, cond=conds)

        """
        Backbone Loss
        """
        mu_loss = self.backbone._losses_(frames_in, frames_gt)

        """
        VAE Loss
        """
        ae_loss, kl_loss = self.backbone.vae._losses_(
            rearrange(torch.cat((frames_in, frames_gt), dim=1), 'b t c h w -> (b t) c h w'),
            rearrange(torch.cat((frames_in, frames_gt), dim=1), 'b t c h w -> (b t) c h w')
        )
        kl_weight = 1E-6
        recon_loss = ae_loss + kl_weight*kl_loss

        """
        Prior Loss at t=T [Noisy]
        """
        # NOTE(review): frames_gt is encoded a second time here — the `hid_gt`
        # computed above could be reused to save one VAE forward pass.
        hid_gt, _ = self.backbone.vae.encode(
            rearrange(frames_gt, 'b t c h w -> (b t) c h w')
        )
        hid_gt = rearrange(hid_gt, '(b t) c h w -> b t c h w', b=B)
        T = torch.ones((B,), device=self.device).long() * (self.num_timesteps - 1)
        mu_noisy = extract(self.sqrt_alphas_cumprod, T, hid_gt.shape) * hid_gt
        sigma_noisy = extract(self.sqrt_one_minus_alphas_cumprod, T, hid_gt.shape)
        log_var_noisy = 2*torch.log(sigma_noisy)
        prior_loss = self.kl_from_standard_normal(mu_noisy, log_var_noisy)

        return recon_loss, mu_loss, diff_loss, prior_loss

    def kl_from_standard_normal(self, mean, log_var):
        """Mean KL divergence of N(mean, exp(log_var)) from N(0, I)."""
        kl = 0.5 * (log_var.exp() + mean.square() - 1.0 - log_var)
        return kl.mean()

    @autocast(enabled = False)
    def q_sample(self, x_start, t, noise = None):
        """Forward diffusion: sample x_t ~ q(x_t | x_0). Runs in full precision."""
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, t, noise=None, offset_noise_strength=None, cond=None):
        """SNR-weighted MSE between the denoiser output and the objective's target."""
        b, T, c, h, w = x_start.shape

        noise = default(noise, lambda: torch.randn_like(x_start))

        # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
        offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)

        if offset_noise_strength > 0.:
            offset_noise = torch.randn(x_start.shape[:2], device = self.device)
            noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')

        # noise sample
        x = self.q_sample(x_start=x_start, t=t, noise=noise) # Use q_sample here for updating: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L763

        model_out = self.diff_unet(x, t, conds=cond)

        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_start
        elif self.objective == 'pred_v':
            v = self.predict_v(x_start, t, noise)
            target = v
        else:
            raise ValueError(f'unknown objective {self.objective}')

        loss = F.mse_loss(model_out, target, reduction = 'none') # (B, T, C, H, W)
        loss = reduce(loss, 'b ... -> b', 'mean')

        loss = loss * extract(self.loss_weight, t, loss.shape)
        return loss.mean()

    @torch.no_grad()
    def forward(self, input_x, include_mu=False, **kwargs):
        """Inference entry point: sampled prediction, optionally with the backbone output."""
        pred, mu = self.predict(input_x, compute_loss=False)
        if include_mu:
            return pred, mu
        else:
            return pred
564
+
565
+ from stldm.modules import SimVPV2_Model, VAE
566
def model_setup(model_config, print_info=False, cfg_str=None):
    """Build the full GaussianDiffusion model (SimVP backbone + latent-diffusion UNet).

    When `cfg_str` is given, a constant classifier-free-guidance schedule of that
    strength is installed; otherwise guidance is disabled.
    """
    if print_info:
        for line in (
            'Setup a Spatial diffusion with replacing a Temporal attention with Spatial attention',
            'This is a diffusion with the consideration of (BT, C, H, W)',
            'Train it from end to end',
        ):
            print(line)

    backbone = SimVPV2_Model(**model_config['vp_param'])
    unet = LDM(**model_config['stldm_param'])
    model = GaussianDiffusion(vp_model=backbone, diffusion=unet, **model_config['param'])

    if cfg_str is not None:
        scheduler = guidance_scheduler(sampling_step=model_config['param']['timesteps'], const=cfg_str)
    else:
        scheduler = None
    model.setup_guidance(scheduler)

    return model
583
+
584
def ae_setup(model_config):
    """Instantiate the SimVP backbone and return only its VAE component."""
    return SimVPV2_Model(**model_config['vp_param']).vae
589
+
590
def backbone_setup(model_config):
    """Instantiate and return the SimVP video-prediction backbone."""
    return SimVPV2_Model(**model_config['vp_param'])
stldm/submodules.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, math
2
+ from torch import nn
3
+ from einops import rearrange
4
+
5
+ # building block modules
6
def exists(x):
    """Return True when *x* is not None."""
    return x is not None
8
+
9
class LayerNorm(nn.Module):
    """Channel-wise layer norm for (B, C, H, W) tensors with a learnable per-channel gain."""

    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        # Looser epsilon under reduced-precision dtypes.
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        mean = torch.mean(x, dim=1, keepdim=True)
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        return (x - mean) * torch.rsqrt(var + eps) * self.g
19
+
20
+
21
class PreNorm(nn.Module):
    """Normalize the input with LayerNorm before delegating to the wrapped module."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x, *args, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, *args, **kwargs)
30
+
31
class Residual(nn.Module):
    """Skip connection around the wrapped module: out = fn(x, ...) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        out = self.fn(x, *args, **kwargs)
        return out + x
38
+
39
class Block(nn.Module):
    """Conv3x3 -> GroupNorm -> optional FiLM (scale/shift) -> SiLU."""

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        # Fall back to a single group when `groups` does not divide the channel count.
        self.norm = nn.GroupNorm(groups if dim_out % groups == 0 else 1, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        h = self.norm(self.proj(x))
        if scale_shift is not None:
            scale, shift = scale_shift
            h = h * (scale + 1) + shift
        return self.act(h)
58
+
59
class ResnetBlock(nn.Module):
    """Two conv Blocks, the first FiLM-modulated by a time embedding, plus a 1x1 residual."""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        if time_emb_dim is not None:
            # Projects the time embedding to per-channel (scale, shift).
            self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))
        else:
            self.mlp = None

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        scale_shift = None
        if self.mlp is not None and time_emb is not None:
            emb = self.mlp(time_emb)
            # (b, 2*c) -> two (b, c, 1, 1) tensors for scale and shift.
            scale_shift = rearrange(emb, 'b c -> b c 1 1').chunk(2, dim=1)

        h = self.block1(x, scale_shift=scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)
84
+
85
+ """
86
+ Input Tensor and Output Tensor should be in the format of (BT, C, H, W) with # dims = 4
87
+ """
88
class Linear_SpatialAttention(nn.Module):
    """Linear (kernelized) self-attention over non-overlapping spatial patches.

    Input/output: (B*T, C, H, W); H and W must be divisible by `patch_size`.
    Attention is computed independently per (head, patch) using the softmax
    linear-attention factorization, so cost is linear in the number of pixels
    per patch.
    """
    def __init__(self, dim, patch_size, heads=4, dim_head=32):
        super(Linear_SpatialAttention, self).__init__()
        self.scale = dim_head ** -0.5
        self.patch_size = patch_size
        self.heads = heads
        hidden_dim = dim_head*heads # No of Channel for (Q, K, V)
        self.to_qkv = nn.Conv2d(dim, hidden_dim*3, kernel_size=1, padding=0, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, kernel_size=1),
            LayerNorm(dim)
        )

    def forward(self, x):
        assert x.ndim == 4
        BT, C, H, W = x.shape
        nh, nw = H//self.patch_size, W//self.patch_size
        qkv = self.to_qkv(x).chunk(3, dim=1) # qkv tuple in (q, k , v)
        # [B, Head × C, X × P, Y × P] -> [B, Head × X × Y, C, P × P]
        q, k, v = map(lambda t: rearrange(t, 'b (h c) (nh ph) (nw pw) -> b (h nh nw) c (ph pw)', h=self.heads, ph=self.patch_size, pw=self.patch_size, nh=nh, nw=nw), qkv)
        # Linear-attention normalization: q softmaxed over channels, k over positions.
        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)
        q = q*self.scale
        # NOTE(review): unlike Linear_TemporalAttention, v is not rescaled here;
        # confirm the asymmetry is intentional.

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b (h nh nw) c (ph pw) -> b (h c) (nh ph) (nw pw)', h=self.heads, ph=self.patch_size, pw=self.patch_size, nh=nh, nw=nw)
        out = self.to_out(out)
        return out
117
+
118
+ """
119
+ Input Tensor and Output Tensor should be in the format of (B, T, C, H, W) with # dims = 5
120
+ """
121
class Linear_TemporalAttention(nn.Module):
    """Linear self-attention across the time axis, per (head, pixel) location.

    Input/output: (B, T, C, H, W). Frames are flattened to (B*T, C, H, W) for the
    1x1 qkv projection, then attention runs over the T dimension for every
    spatial position independently.
    """
    def __init__(self, dim, heads=4, dim_head=32):
        super(Linear_TemporalAttention, self).__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head*heads # No of Channel for (Q, K, V)
        self.to_qkv = nn.Conv2d(dim, hidden_dim*3, kernel_size=1, padding=0, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, kernel_size=1),
            LayerNorm(dim)
        )

    def forward(self, x):
        assert x.ndim == 5
        B, T, C, H, W = x.shape
        x = x.reshape(B*T, C, H, W)
        qkv = self.to_qkv(x).chunk(3, dim=1) # qkv tuple in (q, k , v)
        # [B, Head × C, X × P, Y × P] -> [B, Head × X × Y, C, P × P]
        q, k, v = map(lambda t: rearrange(t, '(b t) (h c) x y -> b (h x y) c t', h=self.heads, x=H, y=W, t=T), qkv)
        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)
        q = q*self.scale
        # In-place rescaling of v by the number of spatial positions.
        # NOTE(review): the spatial variant has no such rescaling — verify intended.
        v /= (H*W)

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b (h x y) c t -> (b t) (h c) x y', h=self.heads, x=H, y=W, t=T)
        out = self.to_out(out)
        return out.reshape(B, T, C, H, W)
150
+
151
+ # Does not Follow what suggested by the paper as could not ensure the spatial factor of 2
152
def Downsample2D(dim_in, dim_out):
    """Halve the spatial resolution with a stride-2 4x4 convolution."""
    return nn.Conv2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1)
154
+
155
def Upsample2D(dim_in, dim_out):
    """Double the spatial resolution with a stride-2 4x4 transposed convolution."""
    return nn.ConvTranspose2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1)
157
+
158
def ChannelConversion(dim_in, dim_out):
    """Change the channel count with a resolution-preserving 3x3 convolution."""
    return nn.Conv2d(dim_in, dim_out, kernel_size=3, padding=1)
160
+
161
+ """
162
+ Input Tensor and Output Tensor should be in the format of (BT, C, H, W) with # dims = 4
163
+ """
164
class Quadratic_SpatialAttention(nn.Module):
    """Full (quadratic-cost) multi-head self-attention over all H*W positions.

    Input/output: (B*T, C, H, W). Memory scales with (H*W)^2, so this is used on
    the lower-resolution stages of the UNet.
    """
    def __init__(self, dim, heads=4, dim_head=32):
        super(Quadratic_SpatialAttention, self).__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head*heads # No of Channel for (Q, K, V)
        self.to_qkv = nn.Conv2d(dim, hidden_dim*3, kernel_size=1, padding=0, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, kernel_size=1)
        )

    def forward(self, x):
        assert x.ndim == 4
        BT, C, H, W = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1) # qkv tuple in (q, k , v)
        # [B, Head × C, X × P, Y × P] -> [B, Head × X × Y, C, P × P]
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv)
        q = q*self.scale

        # Standard scaled dot-product attention over all spatial positions.
        sim = torch.einsum('b h d i, b h d j -> b h i j', q, k)
        attn = sim.softmax(dim = -1)
        out = torch.einsum('b h i j, b h d j -> b h i d', attn, v)

        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = H, y = W)

        out = self.to_out(out)
        return out
191
+
192
+ """
193
+ Input Tensor and Output Tensor should be in the format of (B, T, C, H, W) with # dims = 5
194
+ """
195
class Quadratic_TemporalAttention(nn.Module):
    """Full self-attention across the T time steps, treating each frame's whole
    (C, H, W) volume as one token.

    Input/output: (B, T, C, H, W). The T x T similarity matrix is cheap, but each
    token has dimension head_channels*H*W.
    """
    def __init__(self, dim, heads=4, dim_head=32):
        super(Quadratic_TemporalAttention, self).__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head*heads # No of Channel for (Q, K, V)
        self.to_qkv = nn.Conv2d(dim, hidden_dim*3, kernel_size=1, padding=0, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, kernel_size=1),
        )

    def forward(self, x):
        assert x.ndim == 5
        B, T, C, H, W = x.shape
        x = x.reshape(B*T, C, H, W)
        qkv = self.to_qkv(x).chunk(3, dim=1) # qkv tuple in (q, k , v)
        # [B, Head × C, X × P, Y × P] -> [B, Head × X × Y, C, P × P]
        q, k, v = map(lambda t: rearrange(t, '(b t) (h c) x y -> b h (c x y) t', h=self.heads, x=H, y=W, t=T), qkv)
        q = q*self.scale

        # T x T attention over time steps.
        sim = torch.einsum('b h d i, b h d j -> b h i j', q, k)
        attn = sim.softmax(dim = -1)
        out = torch.einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, 'b h t (c x y) -> (b t) (h c) x y', h=self.heads, x=H, y=W, t=T)
        out = self.to_out(out)
        return out.reshape(B, T, C, H, W)
221
+
222
+ """
223
+ A series of functions required for Diffusion Model copied from DiffCast code
224
+ """
225
def extract(a, t, x_shape):
    """Index schedule tensor `a` at timesteps `t`, reshaped to broadcast over x_shape."""
    b = t.shape[0]
    vals = a.gather(-1, t)
    return vals.reshape(b, *([1] * (len(x_shape) - 1)))
229
+
230
def linear_beta_schedule(timesteps):
    """Linear beta schedule from the original DDPM paper, rescaled to `timesteps` steps."""
    scale = 1000 / timesteps
    return torch.linspace(scale * 0.0001, scale * 0.02, timesteps, dtype=torch.float64)
238
+
239
def cosine_beta_schedule(timesteps, s = 0.008):
    """Cosine beta schedule (https://openreview.net/forum?id=-NEXDKk8gZ).

    Betas are derived from a squared-cosine alpha-bar curve and clipped to
    [0, 0.999].
    """
    t = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64) / timesteps
    alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)
250
+
251
def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5):
    """Sigmoid beta schedule (https://arxiv.org/abs/2212.11972, Fig. 8).

    Reported to work better than cosine for images larger than 64x64.
    NOTE(review): `clamp_min` is accepted but never used by this implementation.
    """
    t = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64) / timesteps
    v_start = torch.tensor(start / tau).sigmoid()
    v_end = torch.tensor(end / tau).sigmoid()
    alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)
265
+
266
+ # sinusoidal positional embeds
267
class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding of a scalar position/timestep.

    Maps a (B,) tensor to (B, dim): the first half sines, the second half cosines.
    """

    def __init__(self, dim):
        super(SinusoidalPosEmb, self).__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        # Geometric frequency ladder from 1 down to 1/10000.
        step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(-step * torch.arange(half_dim, device=x.device))
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
280
+
281
class Time_MLP(nn.Module):
    """Sinusoidal timestep embedding followed by a two-layer GELU MLP.

    NOTE(review): the `dim` argument is accepted but unused; only `fourier_dim`
    and `time_dim` determine the layer sizes.
    """

    def __init__(self, dim, time_dim, fourier_dim=32):
        super(Time_MLP, self).__init__()
        self.mlp = nn.Sequential(
            SinusoidalPosEmb(fourier_dim),
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

    def forward(self, x):
        return self.mlp(x)
293
+
294
class RelativePositionBias(nn.Module):
    """Learned relative-position attention bias using log-spaced distance buckets
    (T5-style bucketing). forward(n, device) returns a (heads, n, n) bias tensor.
    """
    def __init__(
        self,
        heads = 8,
        num_buckets = 32,
        max_distance = 128
    ):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, num_buckets = 32, max_distance = 128):
        # Half the buckets encode sign (before/after); within each half, small
        # distances get exact buckets and larger ones are log-spaced up to
        # max_distance, after which they saturate into the last bucket.
        ret = 0
        n = -relative_position

        num_buckets //= 2
        ret += (n < 0).long() * num_buckets
        n = torch.abs(n)

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, n, device):
        # Build the n x n signed relative-distance grid, bucket it, and look up
        # one learned bias per head.
        q_pos = torch.arange(n, dtype = torch.long, device = device)
        k_pos = torch.arange(n, dtype = torch.long, device = device)
        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        return rearrange(values, 'i j h -> h i j')
333
+
334
+ """
335
+ Input Tensor and Output Tensor should be in the format of (B, T, C, H, W) with # dims = 5
336
+ """
337
class TemporalAttention_Pos(nn.Module):
    """Multi-head self-attention over the time axis with optional additive
    relative-position bias.

    Input and output tensors are (B, T, C, H, W); attention mixes
    information across T only (each spatial location attends to the same
    location at other timesteps).
    """

    def __init__(self, dim, heads=4, dim_head=32):
        # dim: number of input/output channels (C); heads/dim_head size the attention.
        super(TemporalAttention_Pos, self).__init__()
        self.scale = dim_head ** -0.5  # 1/sqrt(d) attention scaling
        self.heads = heads
        hidden_dim = dim_head*heads # No of Channel for (Q, K, V)
        # 1x1 conv producing Q, K and V stacked along the channel axis.
        self.to_qkv = nn.Conv2d(dim, hidden_dim*3, kernel_size=1, padding=0)
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, kernel_size=1),
        )

    def forward(self, x, rel_pos=None):
        """x: (B, T, C, H, W); rel_pos: optional additive bias broadcastable to
        the (heads, T, T) similarity matrix (e.g. from RelativePositionBias)."""
        assert x.ndim == 5
        B, T, C, H, W = x.shape
        x = x.reshape(B*T, C, H, W)
        qkv = self.to_qkv(x).chunk(3, dim=1) # qkv tuple in (q, k , v)
        # Fold space into the batch so attention runs independently per pixel:
        # (B*T, heads*c, H, W) -> (B*H*W, heads, c, T)
        q, k, v = map(lambda t: rearrange(t, '(b t) (h c) x y -> (b x y) h c t', h=self.heads, x=H, y=W, t=T), qkv)
        q = q*self.scale

        # Similarity over the time axis; i and j index timesteps.
        sim = torch.einsum('b h d i, b h d j -> b h i j', q, k)
        if rel_pos is not None:
            sim += rel_pos
        attn = sim.softmax(dim = -1)
        out = torch.einsum('b h i j, b h d j -> b h i d', attn, v)
        # Unfold space back out of the batch and restore channel-first layout.
        out = rearrange(out, '(b x y) h t c -> (b t) (h c) x y', h=self.heads, x=H, y=W, t=T)
        out = self.to_out(out)
        return out.reshape(B, T, C, H, W)
365
+
366
+
367
class TemporalAttention(nn.Module):
    """Multi-head self-attention over the time axis using per-pixel linear
    projections (no positional bias).

    Input and output tensors are (B, T, C, H, W).
    """

    def __init__(self, dim, heads=4, dim_head=32):
        # dim: number of input/output channels (C); heads/dim_head size the attention.
        super(TemporalAttention, self).__init__()
        self.scale = dim_head ** -0.5  # 1/sqrt(d) attention scaling
        self.heads = heads
        hidden_dim = dim_head*heads
        self.to_k = nn.Linear(dim, hidden_dim, bias=False)
        self.to_q = nn.Linear(dim, hidden_dim, bias=False)
        self.to_v = nn.Linear(dim, hidden_dim, bias=False)
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        """x: (B, T, C, H, W) -> (B, T, C, H, W), attending across T only."""
        assert x.ndim == 5
        B, T, C, H, W = x.shape
        # Treat every spatial location as an independent length-T sequence.
        x = rearrange(x, 'b t c h w -> b (h w) t c')

        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
        # Split the projection into heads: (B, H*W, T, heads*d) -> (B, H*W, heads, T, d)
        q = rearrange(q, '... n (h d) -> ... h n d', h=self.heads) # B (H W) Head T Dim
        k = rearrange(k, '... n (h d) -> ... h n d', h=self.heads)
        v = rearrange(v, '... n (h d) -> ... h n d', h=self.heads)
        q = q*self.scale

        # Similarity over the time axis; i and j index timesteps.
        sim = torch.einsum('... h i d, ... h j d -> ... h i j', q, k)
        attn = sim.softmax(dim=-1)
        out = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v)
        # Merge heads back, project, and restore the (B, T, C, H, W) layout.
        out = rearrange(out, '... h i d -> ... i (h d)', h=self.heads)
        out = self.to_out(out)
        out = rearrange(out, 'b (h w) t c -> b t c h w', h=H, w=W)
        return out
utilspp.py ADDED
@@ -0,0 +1,566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import lpips as lp
5
+ import pandas as pd
6
+ import torchmetrics
7
+ import matplotlib.pyplot as plt
8
+ from bisect import bisect_right
9
+ import torchvision.transforms as T
10
+ from torch import nn
11
+
12
+ from matplotlib.colors import ListedColormap, BoundaryNorm
13
+ from matplotlib.lines import Line2D
14
+
15
+ from data import dutils
16
+
17
+ # =======================================================================
18
+ # Scheduler Helper Function
19
+ # =======================================================================
20
+
21
class SequentialLR(torch.optim.lr_scheduler._LRScheduler):
    """Receives the list of schedulers that is expected to be called sequentially during
    optimization process and milestone points that provides exact intervals to reflect
    which scheduler is supposed to be called at a given epoch.

    Args:
        schedulers (list): List of chained schedulers.
        milestones (list): List of integers that reflects milestone points.

    Example:
        >>> # Assuming optimizer uses lr = 1. for all groups
        >>> # lr = 0.1 if epoch == 0
        >>> # lr = 0.1 if epoch == 1
        >>> # lr = 0.9 if epoch == 2
        >>> # lr = 0.81 if epoch == 3
        >>> # lr = 0.729 if epoch == 4
        >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
        >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
        >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False):
        # All chained schedulers must drive the same optimizer instance.
        for scheduler_idx in range(1, len(schedulers)):
            if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
                raise ValueError(
                    "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
                    "got schedulers at index {} and {} to be different".format(0, scheduler_idx)
                )
        if (len(milestones) != len(schedulers) - 1):
            raise ValueError(
                "Sequential Schedulers expects number of schedulers provided to be one more "
                "than the number of milestone points, but got number of schedulers {} and the "
                "number of milestones to be equal to {}".format(len(schedulers), len(milestones))
            )
        # NOTE(review): this vendored copy deliberately does not call the
        # parent _LRScheduler.__init__; all state lives in the attributes
        # below -- confirm against the upstream torch.optim implementation.
        self.optimizer = optimizer
        self._schedulers = schedulers
        self._milestones = milestones
        self.last_epoch = last_epoch + 1

    def step(self, ref=None):
        # Advance one epoch and delegate to the scheduler owning that epoch.
        # `ref` is the metric forwarded to ReduceLROnPlateau.step(); it is
        # ignored by every other scheduler type.
        self.last_epoch += 1
        # Index of the scheduler active at the (incremented) epoch.
        idx = bisect_right(self._milestones, self.last_epoch)
        if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
            # Exactly at a milestone: restart the incoming scheduler at epoch 0.
            self._schedulers[idx].step(0)
        else:
            # NOTE(review): original carried a "Check HERE" marker -- the
            # ReduceLROnPlateau special-case below is the spot in question.
            if isinstance(self._schedulers[idx], torch.optim.lr_scheduler.ReduceLROnPlateau):
                self._schedulers[idx].step(ref)
            else:
                self._schedulers[idx].step()

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The wrapped scheduler states will also be saved.
        """
        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
        state_dict['_schedulers'] = [None] * len(self._schedulers)

        # Serialize each wrapped scheduler in place of the live objects.
        for idx, s in enumerate(self._schedulers):
            state_dict['_schedulers'][idx] = s.state_dict()

        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        _schedulers = state_dict.pop('_schedulers')
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict['_schedulers'] = _schedulers

        for idx, s in enumerate(_schedulers):
            self._schedulers[idx].load_state_dict(s)
106
+
107
def warmup_lambda(warmup_steps, min_lr_ratio=0.1):
    """Return a LambdaLR multiplier ramping linearly from ``min_lr_ratio`` to 1.

    The multiplier grows linearly over the first ``warmup_steps`` steps and
    stays at 1.0 afterwards.
    """
    def ret_lambda(epoch):
        if epoch > warmup_steps:
            return 1.0
        return min_lr_ratio + (1.0 - min_lr_ratio) * epoch / warmup_steps
    return ret_lambda
114
+
115
+ # =======================================================================
116
+ # Utils in utils :)
117
+ # =======================================================================
118
def to_cpu_tensor(*args):
    '''
    Input arbitrary number of arrays/tensors; each is converted to a CPU torch.Tensor.

    Numpy arrays become torch tensors (note: ``torch.Tensor(...)`` always
    yields the default float dtype); torch tensors are moved to the CPU;
    anything else is passed through unchanged.

    Fix: use isinstance() instead of exact ``type(x) is`` checks, so
    subclasses (e.g. ``torch.nn.Parameter``, numpy ndarray subclasses) are
    handled instead of silently passed through on a non-CPU device.

    Returns a single value when called with one argument, otherwise a list
    in the same order as the inputs.
    '''
    out = []
    for tensor in args:
        if isinstance(tensor, np.ndarray):
            tensor = torch.Tensor(tensor)
        if isinstance(tensor, torch.Tensor):
            tensor = tensor.cpu()
        out.append(tensor)
    # single value input: return single value output
    if len(out) == 1:
        return out[0]
    return out
133
+
134
def merge_leading_dims(tensor, n=2):
    """Collapse the first ``n`` dimensions of ``tensor`` into a single one."""
    trailing = tuple(tensor.shape[n:])
    return tensor.reshape((-1,) + trailing)
139
+
140
+ # =======================================================================
141
+ # Model Preparation, saving & loading (copied from utils.py)
142
+ # =======================================================================
143
def build_model_name(model_type, model_config):
    '''
    Build the model name (without extension).

    Each config entry becomes a ``key-value`` token joined by underscores;
    list/tuple values join their items with dashes, and boolean values
    contribute only their key.
    '''
    parts = [model_type]
    for key, value in model_config.items():
        if type(value) in (list, tuple):
            # Booleans inside a sequence contribute an empty piece (dash kept).
            pieces = ['' if type(item) is bool else str(item) for item in value]
            parts.append(key + '-' + '-'.join(pieces))
        elif type(value) is bool:
            parts.append(key)
        else:
            parts.append(key + '-' + str(value))
    return '_'.join(parts)
158
+
159
def build_model_path(base_dir, dataset_type, model_type, timestamp=None):
    """Build ``base_dir/dataset_type/model_type[/timestamp]``.

    ``timestamp`` may be None (omit the component), True (use the current
    time formatted as YYYYmmddHHMMSS) or an explicit string.
    """
    root = os.path.join(base_dir, dataset_type, model_type)
    if timestamp is None:
        return root
    if timestamp == True:  # equality (not identity) kept: also matches 1, as before
        return os.path.join(root, pd.Timestamp.now().strftime('%Y%m%d%H%M%S'))
    return os.path.join(root, timestamp)
165
+
166
+ # =======================================================================
167
+ # Preprocess Function for Loading HKO-7 dataset
168
+ # =======================================================================
169
+
170
def hko7_preprocess(x_seq, x_mask, dt_clip, args):
    """Convert a raw HKO-7 batch into (input, target) sequence tensors.

    Args:
        x_seq: uint8-range frames, shape (seq_length, batch_size, 1, 480, 480)
            -- inferred from the transpose/divide below; confirm at the caller.
        x_mask: unused here; presumably kept for loader-interface compatibility.
        dt_clip: clipping parameter forwarded to the dutils rescaling helpers.
        args: config object supporting ``in`` checks, with optional
            ``resize``, ``seq_len`` and ``scale`` attributes.

    Returns:
        (x, y): the first ``seq_len`` frames and the remaining frames, each
        of shape (batch_size, T, 1, resize, resize) with values in [0, 1].
    """
    resize = args.resize if 'resize' in args else x_seq.shape[-1]
    seq_len = args.seq_len if 'seq_len' in args else 5

    # post-process on HKO-10: move batch first and normalise to [0, 1]
    x_seq = x_seq.transpose((1, 0, 2, 3, 4)) / 255. # => (batch_size, seq_length, 1, 480, 480)
    # NOTE(review): default branch applies nonlinear_to_linear; 'non-linear'
    # applies the inverse mapping -- confirm against dutils' conventions.
    if 'scale' in args and args.scale == 'non-linear':
        x_seq = dutils.linear_to_nonlinear_batched(x_seq, dt_clip)
    else:
        x_seq = dutils.nonlinear_to_linear_batched(x_seq, dt_clip)

    b, t, c, h, w = x_seq.shape
    assert c == 1, f'# channels ({c}) != 1'

    # resize (downsample) the images if necessary
    x_seq = torch.Tensor(x_seq).float().reshape((b*t, c, h, w))
    if resize != h:
        tform = T.Compose([
            T.ToPILImage(),
            T.Resize(resize),
            T.ToTensor(),
        ])
    else:
        # Identity transform keeps the per-frame stack below uniform.
        tform = T.Compose([])

    x_seq = torch.stack([tform(x_frame) for x_frame in x_seq], dim=0)
    x_seq = x_seq.reshape((b, t, c, resize, resize))

    # Split along time into conditioning input and prediction target.
    x, y = x_seq[:, :seq_len], x_seq[:, seq_len:]
    return x, y
200
+
201
+ # =======================================================================
202
+ # Evaluation Metrics-Related
203
+ # =======================================================================
204
+
205
def mae(*args):
    """Mean absolute (L1) error between two tensors, as a numpy scalar.

    Was a lambda bound to a name (PEP 8 E731); a def carries a docstring
    and a proper __name__ for tracebacks. Call signature is unchanged.
    """
    return torch.nn.functional.l1_loss(*args).cpu().detach().numpy()

def mse(*args):
    """Mean squared (L2) error between two tensors, as a numpy scalar."""
    return torch.nn.functional.mse_loss(*args).cpu().detach().numpy()
207
+
208
def ssim(y_pred, y):
    """Mean SSIM over all frames of (B, T, C, H, W) sequences in [0, 1]."""
    target, pred = to_cpu_tensor(y, y_pred)
    batch, steps, chans, height, width = target.shape
    flat = (batch * steps, chans, height, width)
    # Fold time into the batch axis and clamp into the metric's data range
    # (to further ensure none of the input is negative).
    target = torch.clamp(target.reshape(flat), 0, 1)
    pred = torch.clamp(pred.reshape(flat), 0, 1)
    metric = torchmetrics.image.StructuralSimilarityIndexMeasure(data_range=1.0)
    return metric(pred, target)
217
+
218
def psnr(y_pred, y):
    """Average per-frame PSNR over (B, T, C, H, W) sequences in [0, 1]."""
    target, pred = to_cpu_tensor(y, y_pred)
    batch, steps, chans, height, width = target.shape
    flat = (batch * steps, chans, height, width)
    target = target.reshape(flat)
    pred = pred.reshape(flat)
    total = target.shape[0]
    acc_score = 0
    # Average frame-level scores; a fresh metric object per frame, as before.
    for idx in range(total):
        metric = torchmetrics.image.PeakSignalNoiseRatio(data_range=1.0)
        acc_score += metric(pred[idx], target[idx]) / total
    return acc_score
227
+
228
GLOBAL_LPIPS_OBJ = None # module-level cache: the LPIPS network is built once and reused
def lpips64(y_pred, y, net='vgg'):
    """LPIPS perceptual distance at 64x64 resolution, averaged over frames.

    Merges the (B, T) leading dims, bicubically resizes to 64x64, rescales
    [0, 1] inputs to the [-1, 1] range LPIPS expects, and returns the mean
    distance.

    NOTE(review): ``net`` and the device are honoured only on the first
    call; later calls reuse the cached network regardless of either.
    """
    # convert the image range into [-1, 1], assuming the input range to be [0, 1]
    y = merge_leading_dims(y)
    y_pred = merge_leading_dims(y_pred)

    # Bicubic interpolation can overshoot, hence the clamp back into [0, 1].
    y = torch.nn.functional.interpolate(y, (64, 64), mode='bicubic').clamp(0,1)
    y_pred = torch.nn.functional.interpolate(y_pred, (64, 64), mode='bicubic').clamp(0,1)

    y = (2 * y - 1)
    y_pred = (2 * y_pred - 1)
    global GLOBAL_LPIPS_OBJ
    if GLOBAL_LPIPS_OBJ is None:
        GLOBAL_LPIPS_OBJ = lp.LPIPS(net=net).to(y.device)
    return GLOBAL_LPIPS_OBJ(y_pred, y).mean()
243
+
244
def tfpn(y_pred, y, threshold, radius=1):
    '''
    Confusion-matrix counts of a thresholded nowcast.

    Merges the (B, T) leading dims, optionally max-pools with window
    ``radius`` BEFORE thresholding, binarises both maps at ``threshold``,
    and returns (tp, tn, fp, fn) as CPU scalars.

    NOTE(review): pooling precedes thresholding here, while tfpn_pool
    thresholds first -- confirm which ordering is intended when radius > 1.
    '''
    y = merge_leading_dims(y)
    y_pred = merge_leading_dims(y_pred)
    with torch.no_grad():
        if radius > 1:
            # Non-overlapping max pooling (stride defaults to the window size).
            pool = nn.MaxPool2d(radius)
            y = pool(y)
            y_pred = pool(y_pred)
        y = torch.where(y >= threshold, 1, 0)
        y_pred = torch.where(y_pred >= threshold, 1, 0)
        # Inputs are already 0/1 integers here, so the extra `threshold`
        # argument is presumably kept only for API symmetry.
        mat = torchmetrics.functional.confusion_matrix(y_pred, y, task='binary', threshold=threshold)
        (tn, fp), (fn, tp) = to_cpu_tensor(mat)
        return tp, tn, fp, fn
260
+
261
def tfpn_pool(y_pred, y, threshold, radius):
    '''
    Confusion-matrix counts with threshold-then-pool ordering.

    Binarises both maps at ``threshold`` first, then max-pools the 0/1
    maps with window ``radius`` and stride ``radius // 4`` (falling back
    to ``radius``, i.e. non-overlapping windows, when radius < 4).
    Returns (tp, tn, fp, fn) as CPU scalars.
    '''
    y_pred = merge_leading_dims(y_pred)
    y = merge_leading_dims(y)
    pool = nn.MaxPool2d(radius, stride=radius//4 if radius//4 > 0 else radius)
    with torch.no_grad():
        # Binarise first (floats so MaxPool2d accepts them), then pool.
        y = torch.where(y>=threshold, 1, 0).float()
        y_pred = torch.where(y_pred>=threshold, 1, 0).float()
        y = pool(y)
        y_pred = pool(y_pred)
        mat = torchmetrics.functional.confusion_matrix(y_pred, y, task='binary', threshold=threshold)
        (tn, fp), (fn, tp) = to_cpu_tensor(mat)
        return tp, tn, fp, fn
273
+
274
def csi(tp, tn, fp, fn):
    '''Critical Success Index: tp / (tp + fn + fp). Larger is better; 0 when undefined.'''
    denom = tp + fn + fp
    # Guard against an empty confusion matrix (all-dry case).
    return 0. if denom < 1e-7 else tp / denom
279
+
280
def hss(tp, tn, fp, fn):
    '''Heidke Skill Score in (-inf, 1]. Larger is better; 0 when undefined.'''
    expected = (tp + fn) * (fn + tn) + (tp + fp) * (fp + tn)
    if expected == 0:
        return 0.
    return 2 * (tp * tn - fp * fn) / expected
285
+
286
+ # =======================================================================
287
+ # Data Visualization
288
+ # =======================================================================
289
+
290
def torch_visualize(sequences, savedir=None, horizontal=10, vmin=0, vmax=1):
    '''
    Plot one or more image sequences as rows of grayscale frames.

    input: sequences, a list/dict of numpy/torch arrays with shape (B, T, C, H, W)
           C is assumed to be 1 and squeezed
           If batch > 1, only the first sequence will be printed
    savedir: optional output path; a '.png' extension is appended when
             missing. When None/empty, the figure is shown interactively.

    Fix: the old save expression `savedir + '' if savedir.endswith('.png')
    else '.png'` parenthesizes as `(savedir + '') if ... else '.png'`, so a
    path without the extension was silently saved to a file literally named
    '.png', discarding `savedir`.
    '''
    # First pass: compute the vertical height and convert to proper format
    vertical = 0
    display_texts = []
    if isinstance(sequences, dict):
        temp = []
        for k, v in sequences.items():
            vertical += int(np.ceil(v.shape[1] / horizontal))
            temp.append(v)
            display_texts.append(k)
        sequences = temp
    else:
        for i, sequence in enumerate(sequences):
            vertical += int(np.ceil(sequence.shape[1] / horizontal))
            display_texts.append(f'Item {i+1}')
    # NOTE(review): with exactly one sequence, to_cpu_tensor returns the
    # tensor itself (not a 1-list); the loop below then iterates its batch
    # dim -- pre-existing behavior, kept as-is.
    sequences = to_cpu_tensor(*sequences)
    # Plot the sequences
    j = 0
    fig, axes = plt.subplots(vertical, horizontal, figsize=(2*horizontal, 2*vertical), tight_layout=True)
    plt.setp(axes, xticks=[], yticks=[])
    for k, sequence in enumerate(sequences):
        # only take the first batch, now seq[0] is the temporal dim
        sequence = sequence[0].squeeze() # (T, H, W)
        axes[j, 0].set_ylabel(display_texts[k])
        for i, frame in enumerate(sequence):
            # Wrap long sequences onto multiple grid rows.
            j_shift = j + i // horizontal
            i_shift = i % horizontal
            axes[j_shift, i_shift].imshow(frame, vmin=vmin, vmax=vmax, cmap='gray')
        j += int(np.ceil(sequence.shape[0] / horizontal))
    if savedir:
        plt.savefig(savedir if savedir.endswith('.png') else savedir + '.png')
        plt.close()
    else:
        plt.show()
329
+
330
+ """ Visualize function with colorbar and a line seprate input and output """
331
def color_visualize(sequences, savedir='', horizontal=5, skip=1, ypos=0):
    '''
    Plot sequences with the VIL radar palette, a colorbar and a dashed line
    separating the input row from the prediction rows.

    input: sequences, a list/dict of numpy/torch arrays with shape (B, T, C, H, W)
    C is assumed to be 1 and squeezed
    If batch > 1, only the first sequence will be printed
    savedir: path to save the figure to; falsy shows it interactively.
    skip: temporal stride used only for the $t+k$ axis labels.
    ypos: figure-fraction height of the divider line; 0 places it just
        below the first (input) row.
    '''
    plt.style.use(['science', 'no-latex'])  # requires the scienceplots package
    # Colors/levels for pixel values scaled to [0, 255] below.
    VIL_COLORS = [[0, 0, 0],
                  [0.30196078431372547, 0.30196078431372547, 0.30196078431372547],
                  [0.1568627450980392, 0.7450980392156863, 0.1568627450980392],
                  [0.09803921568627451, 0.5882352941176471, 0.09803921568627451],
                  [0.0392156862745098, 0.4117647058823529, 0.0392156862745098],
                  [0.0392156862745098, 0.29411764705882354, 0.0392156862745098],
                  [0.9607843137254902, 0.9607843137254902, 0.0],
                  [0.9294117647058824, 0.6745098039215687, 0.0],
                  [0.9411764705882353, 0.43137254901960786, 0.0],
                  [0.6274509803921569, 0.0, 0.0],
                  [0.9058823529411765, 0.0, 1.0]]

    VIL_LEVELS = [0.0, 16.0, 31.0, 59.0, 74.0, 100.0, 133.0, 160.0, 181.0, 219.0, 255.0]

    # First pass: compute the vertical height and convert to proper format
    vertical = 0
    display_texts = []
    if (type(sequences) is dict):
        temp = []
        for k, v in sequences.items():
            vertical += int(np.ceil(v.shape[1] / horizontal))
            temp.append(v)
            display_texts.append(k)
        sequences = temp
    else:
        for i, sequence in enumerate(sequences):
            vertical += int(np.ceil(sequence.shape[1] / horizontal))
            display_texts.append(f'Item {i+1}')
    sequences = to_cpu_tensor(*sequences)
    # Plot the sequences
    j = 0
    fig, axes = plt.subplots(vertical, horizontal, figsize=(2*horizontal, 2*vertical), tight_layout=True)
    plt.subplots_adjust(hspace=0.0, wspace=0.0) # tight layout
    plt.setp(axes, xticks=[], yticks=[])
    for k, sequence in enumerate(sequences):
        # only take the first batch, now seq[0] is the temporal dim
        sequence = sequence[0].squeeze() # (T, H, W)

        ## =================
        # = labels of time =
        # First row is labelled $t-..0$ (inputs), last row $t+1..$ (outputs).
        if k == 0:
            for i in range(len(sequence)):
                axes[j, i].set_xlabel(f'$t-{(len(sequence)-i)-1}$', fontsize=16)
                axes[j, i].xaxis.set_label_position('top')
        elif k == len(sequences)-1:
            for i in range(len(sequence)):
                axes[j, i].set_xlabel(f'$t+{skip*i+1}$', fontsize=16)
                axes[j, i].xaxis.set_label_position('bottom')
        ## =================
        axes[j, 0].set_ylabel(display_texts[k], fontsize=16)
        for i, frame in enumerate(sequence):
            # Wrap long sequences onto multiple grid rows.
            j_shift = j + i // horizontal
            i_shift = i % horizontal
            # Frames are in [0, 1]; scale to [0, 255] to match VIL_LEVELS.
            im = axes[j_shift, i_shift].imshow(frame*255, cmap=ListedColormap(VIL_COLORS), \
                norm=BoundaryNorm(VIL_LEVELS, ListedColormap(VIL_COLORS).N))
        j += int(np.ceil(sequence.shape[0] / horizontal))

    ## = plot splitting line =
    if ypos == 0:
        # Auto-place the divider just under the first (input) row.
        ypos = 1 - 1 / len(sequences) - 0.017
    fig.lines.append(Line2D((0, 1), (ypos, ypos), transform=fig.transFigure, ls='--', linewidth=2, color='#444'))
    # color bar
    cax = fig.add_axes([1, 0.05, 0.02, 0.5])
    fig.colorbar(im, cax=cax)
    ## =================
    if savedir:
        # NOTE(review): under `if savedir:` the conditional below always
        # evaluates to `savedir` itself; presumably an 'out.png' fallback
        # was intended -- confirm before changing.
        plt.savefig(savedir + '' if len(savedir)>0 else 'out.png')
        plt.close()
    else:
        plt.show()
408
+
409
+ from tempfile import NamedTemporaryFile
410
+
411
+ """ Visualize function with colorbar and a line seprate input and output """
412
def gradio_visualize(sequences, horizontal=5, skip=1, ypos=0):
    '''
    Render sequences with the VIL radar palette to a temporary PNG for Gradio.

    input: sequences, a list/dict of numpy/torch arrays; unlike
    color_visualize this applies ``.squeeze()`` without taking ``[0]``, so
    it presumably expects a size-1 batch dim (squeezed away) -- confirm at
    the caller.
    C is assumed to be 1 and squeezed.

    Returns the path of a NamedTemporaryFile created with delete=False; the
    caller (or the OS temp cleanup) owns its deletion.
    '''
    plt.style.use(['science', 'no-latex'])  # requires the scienceplots package
    # Colors/levels for pixel values scaled to [0, 255] below.
    VIL_COLORS = [[0, 0, 0],
                  [0.30196078431372547, 0.30196078431372547, 0.30196078431372547],
                  [0.1568627450980392, 0.7450980392156863, 0.1568627450980392],
                  [0.09803921568627451, 0.5882352941176471, 0.09803921568627451],
                  [0.0392156862745098, 0.4117647058823529, 0.0392156862745098],
                  [0.0392156862745098, 0.29411764705882354, 0.0392156862745098],
                  [0.9607843137254902, 0.9607843137254902, 0.0],
                  [0.9294117647058824, 0.6745098039215687, 0.0],
                  [0.9411764705882353, 0.43137254901960786, 0.0],
                  [0.6274509803921569, 0.0, 0.0],
                  [0.9058823529411765, 0.0, 1.0]]

    VIL_LEVELS = [0.0, 16.0, 31.0, 59.0, 74.0, 100.0, 133.0, 160.0, 181.0, 219.0, 255.0]

    # First pass: compute the vertical height and convert to proper format
    vertical = 0
    display_texts = []
    if (type(sequences) is dict):
        temp = []
        for k, v in sequences.items():
            vertical += int(np.ceil(v.shape[1] / horizontal))
            temp.append(v)
            display_texts.append(k)
        sequences = temp
    else:
        for i, sequence in enumerate(sequences):
            vertical += int(np.ceil(sequence.shape[1] / horizontal))
            display_texts.append(f'Item {i+1}')
    sequences = to_cpu_tensor(*sequences)
    # Plot the sequences
    j = 0
    fig, axes = plt.subplots(vertical, horizontal, figsize=(2*horizontal, 2*vertical), tight_layout=True)
    plt.subplots_adjust(hspace=0.0, wspace=0.0) # tight layout
    plt.setp(axes, xticks=[], yticks=[])
    for k, sequence in enumerate(sequences):
        # Squeeze singleton dims (expects the batch dim to vanish here).
        sequence = sequence.squeeze() # (T, H, W)

        ## =================
        # = labels of time =
        # First row is labelled $t-..0$ (inputs), last row $t+1..$ (outputs).
        if k == 0:
            for i in range(len(sequence)):
                axes[j, i].set_xlabel(f'$t-{(len(sequence)-i)-1}$', fontsize=16)
                axes[j, i].xaxis.set_label_position('top')
        elif k == len(sequences)-1:
            for i in range(len(sequence)):
                axes[j, i].set_xlabel(f'$t+{skip*i+1}$', fontsize=16)
                axes[j, i].xaxis.set_label_position('bottom')
        ## =================
        axes[j, 0].set_ylabel(display_texts[k], fontsize=16)
        for i, frame in enumerate(sequence):
            # Wrap long sequences onto multiple grid rows.
            j_shift = j + i // horizontal
            i_shift = i % horizontal
            # Frames are in [0, 1]; scale to [0, 255] to match VIL_LEVELS.
            im = axes[j_shift, i_shift].imshow(frame*255, cmap=ListedColormap(VIL_COLORS), \
                norm=BoundaryNorm(VIL_LEVELS, ListedColormap(VIL_COLORS).N))
        j += int(np.ceil(sequence.shape[0] / horizontal))

    ## = plot splitting line =
    if ypos == 0:
        # Auto-place the divider just under the first (input) row.
        ypos = 1 - 1 / len(sequences) - 0.017
    fig.lines.append(Line2D((0, 1), (ypos, ypos), transform=fig.transFigure, ls='--', linewidth=2, color='#444'))
    # color bar
    cax = fig.add_axes([1, 0.05, 0.02, 0.5])
    fig.colorbar(im, cax=cax)

    # Save the figure to a temporary file
    with NamedTemporaryFile(suffix=".png", delete=False) as ff:
        fig.savefig(ff.name)
        file_path = ff.name

    # It's important to close the figure to prevent memory leaks
    plt.close(fig)

    return file_path
493
+
494
+ import matplotlib.animation as animation
495
+
496
def gradio_gif(sequences, T):
    '''
    Render an ensemble of sequences as a side-by-side animated GIF for Gradio.

    input: sequences, a dict mapping names to arrays (iterated via
    .values(); one subplot per entry, labelled "Ensemble i"). Each entry is
    indexed as sequence[t] then squeezed, so frames are presumed (T, 1, H, W)
    or (T, H, W) -- confirm at the caller.
    T: number of animation frames to render.

    Returns the path of a NamedTemporaryFile created with delete=False; the
    caller (or the OS temp cleanup) owns its deletion. Requires the pillow
    writer for GIF output.
    '''
    plt.style.use(['science', 'no-latex'])  # requires the scienceplots package
    # Colors/levels for pixel values scaled to [0, 255] below.
    VIL_COLORS = [[0, 0, 0],
                  [0.30196078431372547, 0.30196078431372547, 0.30196078431372547],
                  [0.1568627450980392, 0.7450980392156863, 0.1568627450980392],
                  [0.09803921568627451, 0.5882352941176471, 0.09803921568627451],
                  [0.0392156862745098, 0.4117647058823529, 0.0392156862745098],
                  [0.0392156862745098, 0.29411764705882354, 0.0392156862745098],
                  [0.9607843137254902, 0.9607843137254902, 0.0],
                  [0.9294117647058824, 0.6745098039215687, 0.0],
                  [0.9411764705882353, 0.43137254901960786, 0.0],
                  [0.6274509803921569, 0.0, 0.0],
                  [0.9058823529411765, 0.0, 1.0]]

    VIL_LEVELS = [0.0, 16.0, 31.0, 59.0, 74.0, 100.0, 133.0, 160.0, 181.0, 219.0, 255.0]

    horizontal = len(sequences)
    fig_size = 3
    fig, axes = plt.subplots(nrows=1, ncols=horizontal, figsize=(fig_size*horizontal, fig_size), tight_layout=True)
    plt.subplots_adjust(hspace=0.0, wspace=0.0) # tight layout
    plt.setp(axes, xticks=[], yticks=[])
    # Draw the first frame of every ensemble member as the initial image.
    for i, sequence in enumerate(sequences.values()):
        axes[i].set_xticks([])
        axes[i].set_yticks([])
        axes[i].set_xlabel(f'Ensemble {i+1}', fontsize=12)
        frame = sequence[0].squeeze()
        im = axes[i].imshow(frame*255, cmap=ListedColormap(VIL_COLORS), \
            norm=BoundaryNorm(VIL_LEVELS, ListedColormap(VIL_COLORS).N), animated=True)

    title = fig.suptitle('', y=0.9, x=0.505, fontsize=16) # Initialize an empty super title

    # Colorbar taken from the last image drawn above.
    fig.colorbar(im)

    def animate(t):
        # Redraw every panel with frame t (a new imshow per frame).
        for i, sequence in enumerate(sequences.values()):
            frame = sequence[t].squeeze()
            im = axes[i].imshow(frame*255, cmap=ListedColormap(VIL_COLORS), \
                norm=BoundaryNorm(VIL_LEVELS, ListedColormap(VIL_COLORS).N), animated=True)
        plt.subplots_adjust(hspace=0.0, wspace=0.0) # tight layout

        title.set_text(f'$t + {t}$') # update the title text

        return fig,

    ani = animation.FuncAnimation(fig, animate, frames=T, interval=750, blit=True, repeat_delay=50,)

    # Save the figure to a temporary file
    with NamedTemporaryFile(suffix=".gif", delete=False) as ff:
        ani.save(ff.name, writer='pillow', fps=5)
        file_path = ff.name

    plt.close()
    return file_path
554
+
555
+ # import matplotlib.pyplot as plt
556
+ # import matplotlib.animation as animation
557
+ # def make_gif(frames, save_path):
558
+ # fig, ax = plt.subplots(figsize=(4,4))
559
+ # im = ax.imshow(frames[0].squeeze(), cmap='gray', vmin=0, vmax=1, animated=True)
560
+ # ax.set_axis_off()
561
+
562
+ # def update(i):
563
+ # im.set_array(frames[i].squeeze())
564
+ # return im,
565
+ # animation_fig =
566
+ # animation_fig.save(f"./{save_path}.gif")