Add Application file
Browse files- LICENSE +21 -0
- app.py +56 -0
- requirements.txt +147 -0
- stldm/__init__.py +5 -0
- stldm/__pycache__/__init__.cpython-38.pyc +0 -0
- stldm/__pycache__/__init__.cpython-39.pyc +0 -0
- stldm/__pycache__/config.cpython-38.pyc +0 -0
- stldm/__pycache__/inference.cpython-38.pyc +0 -0
- stldm/__pycache__/modules.cpython-38.pyc +0 -0
- stldm/__pycache__/modules.cpython-39.pyc +0 -0
- stldm/__pycache__/simvpv2.cpython-38.pyc +0 -0
- stldm/__pycache__/simvpv2.cpython-39.pyc +0 -0
- stldm/__pycache__/stldm.cpython-38.pyc +0 -0
- stldm/__pycache__/stldm.cpython-39.pyc +0 -0
- stldm/__pycache__/stldm_hf.cpython-38.pyc +0 -0
- stldm/__pycache__/stldm_spatial.cpython-38.pyc +0 -0
- stldm/__pycache__/stldm_spatial.cpython-39.pyc +0 -0
- stldm/__pycache__/submodules.cpython-38.pyc +0 -0
- stldm/__pycache__/submodules.cpython-39.pyc +0 -0
- stldm/config.py +115 -0
- stldm/inference.py +99 -0
- stldm/modules.py +126 -0
- stldm/simvpv2.py +431 -0
- stldm/stldm.py +612 -0
- stldm/stldm_hf.py +620 -0
- stldm/stldm_spatial.py +593 -0
- stldm/submodules.py +395 -0
- utilspp.py +566 -0
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 sqfoo
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
app.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import gradio as gr
|
| 4 |
+
|
| 5 |
+
from stldm import InferenceHub
|
| 6 |
+
from stldm.config import STLDM_HKO
|
| 7 |
+
from data.dutils import resize
|
| 8 |
+
from utilspp import gradio_visualize, gradio_gif
|
| 9 |
+
|
| 10 |
+
def nowcasting(file, cfg_str, ensemble_no):
    """Produce an ensemble STLDM nowcast for an uploaded .npy radar sequence.

    Params:
    - file: uploaded file object whose .name points at a .npy array of frames
    - cfg_str: classifier-free guidance scale forwarded to InferenceHub
    - ensemble_no: number of independent ensemble members to sample

    Returns the figure built by `gradio_gif` (context frames + each prediction).
    """
    # Build a fresh Hugging Face-backed forecaster with the requested CFG scale.
    forecaster = InferenceHub(
        model_config=STLDM_HKO,
        cfg_str=cfg_str,
        model_type='HF',
    )

    # Load and normalize the uploaded frames into [0, 1].
    frames = torch.tensor(np.load(file.name))
    if frames.ndim not in (5, 4):
        raise ValueError("Please specify the input has the format of (T C H W)")

    if frames.max() > 1:
        frames = frames / 255.0  # assumes 8-bit pixel values when > 1
    frames = frames.clamp(0, 1)
    if frames.ndim == 4:
        frames = frames.unsqueeze(0)
    frames = resize(frames, 128)  # resize the data to 128 x 128

    if frames.shape[1] < 5:
        raise ValueError("The input should have at least 5 frames for STLDM to predict")
    # Keep only the last 5 frames of the first sequence as conditioning context.
    context = frames[0, -5:]

    predictions = {
        f'Ensemble {member + 1}': torch.cat(
            (context, forecaster(input_x=context, include_mu=False)), dim=0
        )
        for member in range(ensemble_no)
    }

    return gradio_gif(predictions, len(predictions['Ensemble 1']))
|
| 42 |
+
|
| 43 |
+
# Gradio UI: upload a .npy radar sequence, pick the CFG scale and ensemble
# size, then render the ensemble nowcast produced by `nowcasting`.
with gr.Blocks() as demo:
    gr.Markdown("# STLDM official demo for nowcasting")
    gr.Markdown("Please upload the radar sequences with **at least 5 frames** in the format of .npy file, and **STLDM** will predict the future 20 frames based on the past 5 frames.")
    gr.Markdown('Please refer to [paper](https://arxiv.org/abs/2512.21118) and [code](https://github.com/sqfoo/stldm_official) for more details about STLDM.')

    # Fixed typo in the user-facing label: "squences" -> "sequences".
    file_input = gr.File(label="Upload the input radar sequences", file_types=[".npy"])
    cfg_str = gr.Slider(0.0, 2.0, value=1.0, step=0.1, label="Classifier Free Guidance Scale")
    ensemble_no = gr.Slider(1, 10, value=2, step=1, label="How many ensemble predictions?")

    output = gr.Image(label="Nowcasting Results")
    btn = gr.Button("Forecast Now!")
    btn.click(fn=nowcasting, inputs=[file_input, cfg_str, ensemble_no], outputs=output)

demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
absl-py==2.0.0
|
| 2 |
+
antlr4-python3-runtime==4.9.3
|
| 3 |
+
anyio==4.12.0
|
| 4 |
+
argon2-cffi==25.1.0
|
| 5 |
+
argon2-cffi-bindings==25.1.0
|
| 6 |
+
arrow==1.4.0
|
| 7 |
+
asttokens==3.0.1
|
| 8 |
+
async-lru==2.0.5
|
| 9 |
+
attrs==25.4.0
|
| 10 |
+
babel==2.17.0
|
| 11 |
+
beautifulsoup4==4.14.3
|
| 12 |
+
bleach==6.2.0
|
| 13 |
+
cachetools==5.3.2
|
| 14 |
+
certifi==2023.11.17
|
| 15 |
+
cffi==2.0.0
|
| 16 |
+
charset-normalizer==3.3.2
|
| 17 |
+
comm==0.2.3
|
| 18 |
+
contourpy==1.3.0
|
| 19 |
+
cycler==0.12.1
|
| 20 |
+
debugpy==1.8.17
|
| 21 |
+
decorator==5.2.1
|
| 22 |
+
defusedxml==0.7.1
|
| 23 |
+
einops==0.8.1
|
| 24 |
+
exceptiongroup==1.3.1
|
| 25 |
+
executing==2.2.1
|
| 26 |
+
fastjsonschema==2.21.2
|
| 27 |
+
fonttools==4.45.0
|
| 28 |
+
fqdn==1.5.1
|
| 29 |
+
google-auth==2.23.4
|
| 30 |
+
google-auth-oauthlib==0.4.6
|
| 31 |
+
grpcio==1.59.3
|
| 32 |
+
h11==0.16.0
|
| 33 |
+
h5py==3.7.0
|
| 34 |
+
httpcore==1.0.9
|
| 35 |
+
httpx==0.28.1
|
| 36 |
+
idna==3.4
|
| 37 |
+
imageio==2.33.0
|
| 38 |
+
importlib-metadata==6.8.0
|
| 39 |
+
importlib_resources==6.5.2
|
| 40 |
+
ipykernel==6.31.0
|
| 41 |
+
ipython==8.18.1
|
| 42 |
+
ipywidgets==8.1.8
|
| 43 |
+
isoduration==20.11.0
|
| 44 |
+
jedi==0.19.2
|
| 45 |
+
Jinja2==3.1.6
|
| 46 |
+
joblib==1.3.2
|
| 47 |
+
json5==0.12.1
|
| 48 |
+
jsonpointer==3.0.0
|
| 49 |
+
jsonschema==4.25.1
|
| 50 |
+
jsonschema-specifications==2025.9.1
|
| 51 |
+
jupyter==1.1.1
|
| 52 |
+
jupyter-console==6.6.3
|
| 53 |
+
jupyter-events==0.12.0
|
| 54 |
+
jupyter-lsp==2.3.0
|
| 55 |
+
jupyter_client==8.6.3
|
| 56 |
+
jupyter_core==5.8.1
|
| 57 |
+
jupyter_server==2.17.0
|
| 58 |
+
jupyter_server_terminals==0.5.3
|
| 59 |
+
jupyterlab==4.5.0
|
| 60 |
+
jupyterlab_pygments==0.3.0
|
| 61 |
+
jupyterlab_server==2.28.0
|
| 62 |
+
jupyterlab_widgets==3.0.16
|
| 63 |
+
kiwisolver==1.4.5
|
| 64 |
+
lark==1.3.1
|
| 65 |
+
lpips==0.1.4
|
| 66 |
+
Markdown==3.5.1
|
| 67 |
+
MarkupSafe==2.1.3
|
| 68 |
+
matplotlib==3.9.4
|
| 69 |
+
matplotlib-inline==0.2.1
|
| 70 |
+
mistune==3.1.4
|
| 71 |
+
nbclient==0.10.2
|
| 72 |
+
nbconvert==7.16.6
|
| 73 |
+
nbformat==5.10.4
|
| 74 |
+
nest-asyncio==1.6.0
|
| 75 |
+
networkx==3.2.1
|
| 76 |
+
notebook==7.5.0
|
| 77 |
+
notebook_shim==0.2.4
|
| 78 |
+
numpy==1.24.4
|
| 79 |
+
oauthlib==3.2.2
|
| 80 |
+
omegaconf==2.3.0
|
| 81 |
+
opencv-python==4.8.0.74
|
| 82 |
+
overrides==7.7.0
|
| 83 |
+
packaging==23.2
|
| 84 |
+
pandas==1.4.3
|
| 85 |
+
pandocfilters==1.5.1
|
| 86 |
+
parso==0.8.5
|
| 87 |
+
pexpect==4.9.0
|
| 88 |
+
Pillow==10.1.0
|
| 89 |
+
platformdirs==4.4.0
|
| 90 |
+
prometheus_client==0.23.1
|
| 91 |
+
prompt_toolkit==3.0.52
|
| 92 |
+
protobuf==3.19.6
|
| 93 |
+
psutil==7.1.3
|
| 94 |
+
ptyprocess==0.7.0
|
| 95 |
+
pure_eval==0.2.3
|
| 96 |
+
pyasn1==0.5.1
|
| 97 |
+
pyasn1-modules==0.3.0
|
| 98 |
+
pycparser==2.23
|
| 99 |
+
Pygments==2.19.2
|
| 100 |
+
pyparsing==3.1.1
|
| 101 |
+
python-dateutil==2.8.2
|
| 102 |
+
python-json-logger==4.0.0
|
| 103 |
+
pytz==2023.3.post1
|
| 104 |
+
PyWavelets==1.5.0
|
| 105 |
+
PyYAML==6.0
|
| 106 |
+
pyzmq==27.1.0
|
| 107 |
+
referencing==0.36.2
|
| 108 |
+
requests==2.31.0
|
| 109 |
+
requests-oauthlib==1.3.1
|
| 110 |
+
rfc3339-validator==0.1.4
|
| 111 |
+
rfc3986-validator==0.1.1
|
| 112 |
+
rfc3987-syntax==1.1.0
|
| 113 |
+
rpds-py==0.27.1
|
| 114 |
+
rsa==4.9
|
| 115 |
+
SciencePlots==2.2.0
|
| 116 |
+
scikit-image==0.19.3
|
| 117 |
+
scikit-learn==1.1.2
|
| 118 |
+
scipy==1.9.1
|
| 119 |
+
Send2Trash==1.8.3
|
| 120 |
+
six==1.16.0
|
| 121 |
+
soupsieve==2.8
|
| 122 |
+
stack-data==0.6.3
|
| 123 |
+
tensorboard==2.9.0
|
| 124 |
+
tensorboard-data-server==0.6.1
|
| 125 |
+
tensorboard-plugin-wit==1.8.1
|
| 126 |
+
terminado==0.18.1
|
| 127 |
+
threadpoolctl==3.2.0
|
| 128 |
+
tifffile==2023.9.26
|
| 129 |
+
tinycss2==1.4.0
|
| 130 |
+
tomli==2.3.0
|
| 131 |
+
torch==1.12.1+cu116
|
| 132 |
+
torchmetrics==0.11.0
|
| 133 |
+
torchvision==0.13.1+cu116
|
| 134 |
+
tornado==6.5.2
|
| 135 |
+
tqdm==4.66.1
|
| 136 |
+
traitlets==5.14.3
|
| 137 |
+
typing_extensions==4.8.0
|
| 138 |
+
tzdata==2025.2
|
| 139 |
+
uri-template==1.3.0
|
| 140 |
+
urllib3==2.1.0
|
| 141 |
+
wcwidth==0.2.14
|
| 142 |
+
webcolors==24.11.1
|
| 143 |
+
webencodings==0.5.1
|
| 144 |
+
websocket-client==1.9.0
|
| 145 |
+
Werkzeug==3.0.1
|
| 146 |
+
widgetsnbextension==4.0.15
|
| 147 |
+
zipp==3.17.0
|
stldm/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Public surface of the stldm package.
from stldm.stldm import model_setup
from stldm.stldm_spatial import model_setup as spatial_setup
from stldm.inference import InferenceHub

# Dispatch table from model-type tag to its setup function:
# '2D' is the spatial-only variant, '3D' the full spatio-temporal model.
n2n_setup = {'2D': spatial_setup, '3D': model_setup}
|
stldm/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (343 Bytes). View file
|
|
|
stldm/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (1.15 kB). View file
|
|
|
stldm/__pycache__/config.cpython-38.pyc
ADDED
|
Binary file (1.08 kB). View file
|
|
|
stldm/__pycache__/inference.cpython-38.pyc
ADDED
|
Binary file (3.26 kB). View file
|
|
|
stldm/__pycache__/modules.cpython-38.pyc
ADDED
|
Binary file (5.45 kB). View file
|
|
|
stldm/__pycache__/modules.cpython-39.pyc
ADDED
|
Binary file (5.4 kB). View file
|
|
|
stldm/__pycache__/simvpv2.cpython-38.pyc
ADDED
|
Binary file (15.4 kB). View file
|
|
|
stldm/__pycache__/simvpv2.cpython-39.pyc
ADDED
|
Binary file (15.2 kB). View file
|
|
|
stldm/__pycache__/stldm.cpython-38.pyc
ADDED
|
Binary file (18.4 kB). View file
|
|
|
stldm/__pycache__/stldm.cpython-39.pyc
ADDED
|
Binary file (18.4 kB). View file
|
|
|
stldm/__pycache__/stldm_hf.cpython-38.pyc
ADDED
|
Binary file (18.6 kB). View file
|
|
|
stldm/__pycache__/stldm_spatial.cpython-38.pyc
ADDED
|
Binary file (18.2 kB). View file
|
|
|
stldm/__pycache__/stldm_spatial.cpython-39.pyc
ADDED
|
Binary file (18.2 kB). View file
|
|
|
stldm/__pycache__/submodules.cpython-38.pyc
ADDED
|
Binary file (15.4 kB). View file
|
|
|
stldm/__pycache__/submodules.cpython-39.pyc
ADDED
|
Binary file (15.4 kB). View file
|
|
|
stldm/config.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def _make_stldm_config(shape_in, shape_out):
    """Build a standard STLDM configuration for one dataset.

    Params:
    - shape_in: tuple (T, C, H, W) of the conditioning (input) sequence.
    - shape_out: tuple (T, C, H, W) of the predicted (output) sequence.

    Returns a dict with 'model', 'pre', 'post', 'vp_param' (video-prediction
    backbone), 'stldm_param' (latent diffusion UNet) and 'param' (diffusion
    schedule) sections. Only the sequence shapes differ between datasets;
    every other hyper-parameter is shared, so the three dataset configs are
    derived from this single factory instead of three copy-pasted literals.
    """
    return {
        'model': "stldm",
        'pre': None,
        'post': None,
        'vp_param': {
            'shape_in': shape_in,
            'shape_out': shape_out,
            'hid_S': 32,
            'hid_T': 512,
            'N_S': 4,
            'N_T': 8,
            'groups': 8,
            'last_activation': 'sigmoid',
        },
        'stldm_param': {
            'in_ch': 32,
            'chs_mult': [1, 2, 4, 8],
            'num_groups': 8,
            'heads': 4,
            'dim_head': 32,
            'base_ch': 64,
            'patch_size': 16,
        },
        'param': {
            'timesteps': 50,
            'sampling_timesteps': 20,
            'objective': 'pred_v',
        },
    }


# SEVIR: 13 conditioning frames -> 12 predicted frames.
STLDM_SEVIR = _make_stldm_config((13, 1, 128, 128), (12, 1, 128, 128))
# HKO-7 and MeteoNet: 5 conditioning frames -> 20 predicted frames.
STLDM_HKO = _make_stldm_config((5, 1, 128, 128), (20, 1, 128, 128))
STLDM_METEO = _make_stldm_config((5, 1, 128, 128), (20, 1, 128, 128))

# Flattened HKO variant for the Hugging Face ('HF') wrapper: the diffusion
# schedule lives at the top level instead of under 'param', and the
# 'model'/'pre'/'post' keys are omitted.
_hko_base = _make_stldm_config((5, 1, 128, 128), (20, 1, 128, 128))
STLDM_HKO_HF = {
    'vp_param': _hko_base['vp_param'],
    'stldm_param': _hko_base['stldm_param'],
    **_hko_base['param'],
}
|
stldm/inference.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
from typing import Tuple

from stldm.stldm import model_setup, guidance_scheduler
from stldm.stldm_spatial import model_setup as spatial_setup
from stldm.stldm_hf import GaussianDiffusion as hf_setup

# Dispatch table from model-type tag to its constructor:
# '2D' spatial-only, '3D' full spatio-temporal, 'HF' the Hugging Face wrapper.
n2n_setup = {'2D': spatial_setup, '3D': model_setup, 'HF': hf_setup}
|
| 9 |
+
|
| 10 |
+
class InferenceHub:
    """
    Unified inference interface for STLDM models.

    Supports local checkpoints as well as the checkpoint uploaded to
    Hugging Face.

    Params:
    - model_config: dict, a model configuration from "stldm/config.py"
    - model_ckpt: str, path to the model checkpoint. For 'HF' model_type, this can be None.
    - cfg_str: float, the classifier-free guidance strength. If None, no CFG is applied.
    - model_type: str, the type of the model. Options are '2D', '3D', and 'HF'.
    - gpu: 'auto' to move to the first CUDA device when one exists, a device
      index for a specific GPU, or None to leave the model on CPU.
    """

    def __init__(self, model_config, model_ckpt:str=None, cfg_str:float=None, model_type:str='3D', gpu='auto'):
        self.input_size = model_config['vp_param']['shape_in']
        # NOTE(review): reads 'timesteps', not 'sampling_timesteps' — confirm
        # this is the intended step count for the guidance scheduler.
        self.sampling_steps = model_config['param']['timesteps']
        self.model_config = self.setup_config(model_config, model_type)

        self.model = self.setup_model(model_type, self.model_config, model_ckpt)
        self.setup_cfg(cfg_str)

        # Device placement: 'auto' only moves when CUDA is actually available.
        if gpu == 'auto':
            if torch.cuda.device_count() > 0:
                self.model.to(device="cuda")
        elif gpu is not None:
            self.model.to(device=f"cuda:{gpu}")

    def setup_config(self, model_config, model_type):
        """Flatten the config for the 'HF' wrapper; pass through otherwise."""
        if model_type != 'HF':
            return model_config
        return {
            'vp_param': model_config['vp_param'],
            'stldm_param': model_config['stldm_param'],
            **model_config['param'],
        }

    def setup_model(self, model_type, model_config, model_ckpt):
        """Instantiate the requested model variant and put it in eval mode."""
        if model_type not in n2n_setup:
            raise ValueError(f"model_type should be one of {str(list(n2n_setup.keys()))}")

        if model_type == 'HF':
            # Weights come from the Hugging Face hub; model_ckpt is ignored.
            model = n2n_setup[model_type](**model_config).from_pretrained("sqfoo/STLDM_official")
        else:
            model = n2n_setup[model_type](model_config)
            model.load_state_dict(torch.load(model_ckpt))
        model.eval()
        return model

    def setup_cfg(self, cfg_str):
        """Attach a classifier-free guidance schedule to the model (None disables CFG)."""
        if cfg_str is None:
            guidance = None
        else:
            guidance = guidance_scheduler(sampling_step=self.sampling_steps, const=cfg_str)
        self.model.setup_guidance(guidance)

    @torch.no_grad()
    def __call__(self, input_x: torch.tensor, include_mu: bool = False):
        """
        Run inference on the input tensor.

        Params:
        - input_x: torch.tensor, the input tensor with shape (B T C H W) or (T C H W)
        - include_mu: bool, whether to return the intermediate representation 'mu' along with the final prediction
        """
        batched = input_x.ndim == 5
        if not batched and input_x.ndim != 4:
            raise ValueError("Please specify the input has the either format of (B T C H W) or (T C H W)")
        source_device = input_x.device

        if not batched:
            input_x = input_x.unsqueeze(0)

        if input_x.shape[1:] != self.input_size:
            raise ValueError(f"Ensure that the input has the shape of {str(self.input_size)}")

        input_x = input_x.to(self.model.device)
        mu = None
        if include_mu:
            y_pred, mu = self.model(input_x, includ_mu=include_mu)
        else:
            y_pred = self.model(input_x, includ_mu=include_mu)

        # Move results back to wherever the caller's tensor lived.
        if mu is not None:
            mu = mu.to(source_device)
        y_pred = y_pred.to(source_device)

        if not batched:
            y_pred = y_pred[0]
            if mu is not None:
                mu = mu[0]
        return (y_pred, mu) if include_mu else y_pred
|
stldm/modules.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
from stldm.submodules import ChannelConversion
|
| 5 |
+
from stldm.simvpv2 import stride_generator, ConvSC, MidMetaNet
|
| 6 |
+
|
| 7 |
+
class Encoder(nn.Module):
    """Spatial encoder producing a Gaussian posterior (mean, log-variance).

    Stacks ConvSC blocks per the stride schedule, then doubles the channel
    count so the output splits cleanly into mean and log-variance halves.
    """

    def __init__(self, C_in, C_hid, N_S):
        super(Encoder, self).__init__()
        strides = stride_generator(N_S)
        layers = [ConvSC(C_in, C_hid, stride=strides[0])]
        layers.extend(ConvSC(C_hid, C_hid, stride=s) for s in strides[1:])
        layers.append(ChannelConversion(C_hid, 2 * C_hid))
        self.enc = nn.Sequential(*layers)

    def forward(self, x):
        for layer in self.enc:
            x = layer(x)
        # Split the doubled channels into the posterior parameters.
        mean, log_var = torch.chunk(x, 2, dim=1)
        return mean, log_var
|
| 22 |
+
|
| 23 |
+
class Decoder(nn.Module):
    """Spatial decoder: transposed ConvSC upsampling stack plus a 1x1 readout.

    Params:
    - C_hid: latent channel count fed into (and kept through) the stack.
    - C_out: output channel count of the reconstructed frame.
    - N_S: number of ConvSC stages (mirrors the Encoder).
    - last_activation: 'sigmoid' squashes outputs to [0, 1]; anything else
      leaves them unbounded.
    """

    def __init__(self, C_hid, C_out, N_S, last_activation='sigmoid'):
        super(Decoder, self).__init__()
        strides = stride_generator(N_S, reverse=True)
        blocks = [ChannelConversion(C_hid, C_hid)]
        blocks.extend(ConvSC(C_hid, C_hid, stride=s, transpose=True) for s in strides)
        self.dec = nn.Sequential(*blocks)
        self.readout = nn.Conv2d(C_hid, C_out, 1)
        self.last = nn.Sigmoid() if last_activation == 'sigmoid' else nn.Identity()

    def forward(self, x):
        for block in self.dec:
            x = block(x)
        return self.last(self.readout(x))
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class VAE(nn.Module):
    """Per-frame variational autoencoder built from the Encoder/Decoder above."""

    def __init__(self, C_in, hid_S, N_S, last_activation='none'):
        super(VAE, self).__init__()
        self.encoder = Encoder(C_in, hid_S, N_S)
        self.decoder = Decoder(hid_S, C_in, N_S, last_activation)

    def sample_from_standard_normal(self, mean, log_var):
        """Reparameterization trick: mean + sigma * eps with eps ~ N(0, I)."""
        return mean + torch.exp(0.5 * log_var) * torch.randn_like(mean)

    def encode(self, x):
        """Encode a batch of frames (N, C, H, W) to posterior (mean, log_var)."""
        assert x.ndim == 4
        return self.encoder(x)

    def decode(self, z):
        """Decode a batch of latents (N, C, H, W) back to frame space."""
        assert z.ndim == 4
        return self.decoder(z)

    def kl_from_standard_normal(self, mean, log_var):
        """Element-mean KL divergence of N(mean, exp(log_var)) from N(0, I)."""
        kl = 0.5 * (log_var.exp() + mean.square() - 1.0 - log_var)
        return kl.mean()

    def _losses_(self, x, y):
        """Return (reconstruction MSE against y, KL to the standard normal)."""
        mean, log_var = self.encode(x)
        kl_loss = self.kl_from_standard_normal(mean, log_var)

        y_pred = self.forward(x)
        recon_loss = nn.MSELoss()(y_pred, y)
        return recon_loss, kl_loss

    def forward(self, x):
        """Encode, sample a latent via reparameterization, and decode."""
        mu_z, log_var = self.encode(x)
        latent = self.sample_from_standard_normal(mu_z, log_var)
        return self.decode(latent)
|
| 83 |
+
|
| 84 |
+
class SimVPV2_Model(nn.Module):
    """SimVPv2-style predictor with a VAE frame codec and variational translator.

    Each input frame is encoded to a latent posterior and sampled; the latent
    sequence is translated by MidMetaNet, whose doubled output channels are
    split into mean / log-variance for the future latents, which are sampled
    and decoded back to frames.
    """

    def __init__(self, shape_in, shape_out, hid_S=16, hid_T=256, N_S=4, N_T=4,
                 mlp_ratio=8., drop=0.0, drop_path=0.0, spatio_kernel_enc=3,
                 spatio_kernel_dec=3, last_activation='none', act_inplace=True, **kwargs):
        super(SimVPV2_Model, self).__init__()
        T, C, H, W = shape_in  # T is pre_seq_length
        T2, C2, H2, W2 = shape_out  # T2 is output length
        assert C==C2 and H==H2 and W==W2, 'Need to be the same image shape for input and output'
        self.T2 = T2
        self.T = T

        # Latent spatial size after N_S/2 stride-2 downsamples.
        H, W = int(H / 2**(N_S/2)), int(W / 2**(N_S/2))

        self.vae = VAE(C_in=C, hid_S=hid_S, N_S=N_S, last_activation=last_activation)
        # Output channels are doubled: mean + log-variance per future frame.
        self.hid = MidMetaNet(T*hid_S, T2*hid_S*2, hid_T, N_T,
                              input_resolution=(H, W), model_type='gsta',
                              mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)

    def forward(self, x_raw):
        B, T, C, H, W = x_raw.shape
        flat = x_raw.reshape(B*T, C, H, W)

        # Encode every frame and sample its latent.
        mean, log_var = self.vae.encode(flat)
        latents = self.vae.sample_from_standard_normal(mean, log_var)
        *_, C_, H_, W_ = latents.shape
        z = latents.view(B, T, C_, H_, W_)

        # Translate, then split the doubled channels into mean / log-var.
        hid, *_ = self.hid(z)
        hid_mu, hid_log_var = torch.chunk(hid, 2, dim=1)
        sampled = self.vae.sample_from_standard_normal(hid_mu, hid_log_var)

        sampled = sampled.reshape(B*self.T2, C_, H_, W_)
        # Conditioning uses the deterministic mean, not the sampled latents.
        conds_ = hid_mu.reshape(B*self.T2, C_, H_, W_)

        Y = self.vae.decode(sampled)
        Y = Y.reshape(B, self.T2, C, H, W)
        return Y, conds_

    def _losses_(self, x, y):
        """Reconstruction MSE of the predicted sequence against the target."""
        y_pred, *_ = self.forward(x)
        return nn.MSELoss()(y_pred, y)
|
stldm/simvpv2.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
import torch, math
|
| 3 |
+
|
| 4 |
+
# from torchmodels.simvp import ConvSC, stride_generator
|
| 5 |
+
|
| 6 |
+
class SimVPV2_Model(nn.Module):
    r"""SimVP Model

    Implementation of `SimVP: Simpler yet Better Video Prediction
    Just Remove The Skip Connection
    <https://arxiv.org/abs/2206.05099>`_.

    """
    def __init__(self, shape_in, shape_out, hid_S=16, hid_T=256, N_S=4, N_T=4,
                 mlp_ratio=8., drop=0.0, drop_path=0.0, spatio_kernel_enc=3,
                 spatio_kernel_dec=3, last_activation='none', act_inplace=True, recursive=False, **kwargs):
        super(SimVPV2_Model, self).__init__()
        T, C, H, W = shape_in  # T is pre_seq_length
        T2, C2, H2, W2 = shape_out  # T2 is output length
        assert C==C2 and H==H2 and W==W2, 'Need to be the same image shape for input and output'
        self.T2 = T2
        self.T = T

        # Latent spatial size after N_S/2 stride-2 downsamples.
        H, W = int(H / 2**(N_S/2)), int(W / 2**(N_S/2))
        act_inplace = False

        self.enc = Encoder(C, hid_S, N_S)
        self.dec = Decoder(hid_S, C, N_S, last_activation)

        # Recursive mode reuses a T->T translator autoregressively; otherwise a
        # single T->T2 translator produces the whole horizon in one pass.
        self.recursive = recursive
        translator_out = T * hid_S if recursive else T2 * hid_S
        self.hid = MidMetaNet(T*hid_S, translator_out, hid_T, N_T,
                              input_resolution=(H, W), model_type='gsta',
                              mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
        self.last_activation = last_activation

    def forward(self, x_raw, **kwargs):
        B, T, C, H, W = x_raw.shape
        frames = x_raw.reshape(B*T, C, H, W)

        embed = self.enc(frames)
        _, C_, H_, W_ = embed.shape
        z = embed.view(B, T, C_, H_, W_)

        if self.recursive:
            # Roll the translator forward in chunks of T frames, trim to T2.
            steps = self.T2 // self.T
            if self.T2 % self.T != 0:
                steps += 1
            chunks = []
            for _ in range(steps):
                z, _ = self.hid(z)
                chunks.append(z)
            hid = torch.cat(chunks, dim=1)[:, :self.T2]
            conds_ = hid.reshape(-1, C_, H_, W_)
        else:
            hid, conds_ = self.hid(z)

        hid = hid.reshape(B*self.T2, C_, H_, W_)

        Y = self.dec(hid)
        Y = Y.reshape(B, self.T2, C, H, W)
        return Y, conds_, hid.reshape(B, -1, C_, H_, W_)

    def recon_loss(self, x, y):
        """Autoencoding MSE over the concatenated input and target frames."""
        X = torch.cat((x, y), dim=1)
        B, T, C, H, W = X.shape
        X = X.reshape(-1, C, H, W)
        recon = self.dec(self.enc(X))
        return nn.MSELoss()(recon, X)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class MidMetaNet(nn.Module):
    """The hidden Translator of MetaFormer for SimVP.

    Stacks ``N2`` MetaBlocks over the (T*C)-channel folded clip. Modified
    vs upstream SimVP: takes an explicit ``channel_out`` so the last block
    can change the temporal length of the latent sequence.
    """

    def __init__(self, channel_in, channel_out, channel_hid, N2,
                 input_resolution=None, model_type=None,
                 mlp_ratio=4., drop=0.0, drop_path=0.1):
        super(MidMetaNet, self).__init__()
        assert N2 >= 2 and mlp_ratio > 1
        self.N2 = N2
        # Stochastic depth decay rule: linearly increase drop_path per layer.
        dpr = [
            x.item() for x in torch.linspace(1e-2, drop_path, self.N2)]

        # First block: channel_in -> channel_hid.
        enc_layers = [MetaBlock(
            channel_in, channel_hid, input_resolution, model_type,
            mlp_ratio, drop, drop_path=dpr[0], layer_i=0)]
        # Middle blocks keep channel_hid.
        for i in range(1, N2 - 1):
            enc_layers.append(MetaBlock(
                channel_hid, channel_hid, input_resolution, model_type,
                mlp_ratio, drop, drop_path=dpr[i], layer_i=i))
        # Last block: channel_hid -> channel_out.
        enc_layers.append(MetaBlock(
            channel_hid, channel_out, input_resolution, model_type,
            mlp_ratio, drop, drop_path=drop_path, layer_i=N2 - 1))
        self.enc = nn.Sequential(*enc_layers)

    def forward(self, x):
        """x: (B, T, C, H, W) -> (translated clip, same flattened per-frame).

        Returns a pair ``(y, y_flat)`` where y is (B, T', C, H, W) and
        y_flat is y reshaped to (B*T', C, H, W).
        """
        B, T, C, H, W = x.shape
        z = x.reshape(B, T * C, H, W)

        # FIX: the previous version also appended every intermediate
        # activation to a dead `conds` list, keeping all of them alive
        # for the whole forward pass; the list was never used.
        for i in range(self.N2):
            z = self.enc[i](z)

        y = z.reshape(B, -1, C, H, W)
        return y, y.reshape(-1, C, H, W)
|
| 122 |
+
|
| 123 |
+
class MetaBlock(nn.Module):
    """The hidden Translator of MetaFormer for SimVP.

    One gSTA sub-block, followed by a 1x1 conv projection when
    ``in_channels != out_channels``.
    """

    def __init__(self, in_channels, out_channels, input_resolution=None, model_type=None,
                 mlp_ratio=8., drop=0.0, drop_path=0.0, layer_i=0):
        super(MetaBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        model_type = model_type.lower() if model_type is not None else 'gsta'

        if model_type == 'gsta':
            self.block = GASubBlock(
                in_channels, kernel_size=21, mlp_ratio=mlp_ratio,
                drop=drop, drop_path=drop_path, act_layer=nn.GELU)
        else:
            # BUG FIX: `assert False and "msg"` evaluates to a plain `assert
            # False` (the message is never attached). Use the comma form so
            # the AssertionError actually carries the message.
            assert False, "Invalid model_type in SimVP"

        if in_channels != out_channels:
            # 1x1 projection for the channel change.
            self.reduction = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        z = self.block(x)
        return z if self.in_channels == self.out_channels else self.reduction(z)
|
| 147 |
+
|
| 148 |
+
class GASubBlock(nn.Module):
    """A GABlock (gSTA) for SimVP: spatial attention and a conv-MLP, each
    applied as a BatchNorm'd residual branch with LayerScale and
    stochastic depth."""

    def __init__(self, dim, kernel_size=21, mlp_ratio=4.,
                 drop=0., drop_path=0.1, init_value=1e-2, act_layer=nn.GELU):
        super().__init__()
        self.norm1 = nn.BatchNorm2d(dim)
        self.attn = SpatialAttention(dim, kernel_size)
        self.drop_path = nn.Identity() if drop_path <= 0. else DropPath(drop_path)

        self.norm2 = nn.BatchNorm2d(dim)
        self.mlp = MixMlp(
            in_features=dim, hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer, drop=drop)

        # Learnable per-channel residual scales (LayerScale).
        self.layer_scale_1 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Linear: truncated normal; Conv2d: He-style fan-out normal.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'layer_scale_1', 'layer_scale_2'}

    def forward(self, x):
        scale1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
        x = x + self.drop_path(scale1 * self.attn(self.norm1(x)))
        scale2 = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
        x = x + self.drop_path(scale2 * self.mlp(self.norm2(x)))
        return x
|
| 193 |
+
|
| 194 |
+
class SpatialAttention(nn.Module):
    """A Spatial Attention block for SimVP: 1x1 proj -> GELU -> large-kernel
    attention -> 1x1 proj, with an optional additive shortcut."""

    def __init__(self, d_model, kernel_size=21, attn_shortcut=True):
        super().__init__()

        self.proj_1 = nn.Conv2d(d_model, d_model, 1)          # 1x1 conv in
        self.activation = nn.GELU()
        self.spatial_gating_unit = AttentionModule(d_model, kernel_size)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)          # 1x1 conv out
        self.attn_shortcut = attn_shortcut

    def forward(self, x):
        residual = x.clone() if self.attn_shortcut else None
        out = self.proj_1(x)
        out = self.activation(out)
        out = self.spatial_gating_unit(out)
        out = self.proj_2(out)
        if self.attn_shortcut:
            out = out + residual
        return out
|
| 216 |
+
|
| 217 |
+
class AttentionModule(nn.Module):
    """Large Kernel Attention for SimVP.

    Decomposes a large spatial kernel into a small depth-wise conv plus a
    dilated depth-wise conv, then a 1x1 conv produces value/gate halves
    combined as ``sigmoid(gate) * value``.
    """

    def __init__(self, dim, kernel_size, dilation=3):
        super().__init__()
        # Dense depth-wise kernel and its padding.
        d_k = 2 * dilation - 1
        d_p = (d_k - 1) // 2
        # Dilated depth-wise kernel (forced odd) and its padding.
        dd_k = kernel_size // dilation + ((kernel_size // dilation) % 2 - 1)
        dd_p = (dilation * (dd_k - 1) // 2)

        self.conv0 = nn.Conv2d(dim, dim, d_k, padding=d_p, groups=dim)
        self.conv_spatial = nn.Conv2d(
            dim, dim, dd_k, stride=1, padding=dd_p, groups=dim, dilation=dilation)
        self.conv1 = nn.Conv2d(dim, 2*dim, 1)

    def forward(self, x):
        # FIX: removed dead `u = x.clone()` — it allocated a full copy of the
        # input on every forward pass and was never used.
        attn = self.conv0(x)            # depth-wise conv
        attn = self.conv_spatial(attn)  # depth-wise dilated conv

        f_g = self.conv1(attn)
        split_dim = f_g.shape[1] // 2
        f_x, g_x = torch.split(f_g, split_dim, dim=1)
        return torch.sigmoid(g_x) * f_x
|
| 241 |
+
|
| 242 |
+
class DWConv(nn.Module):
    """Depth-wise 3x3 convolution (one filter per channel, padding 1)."""

    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x):
        return self.dwconv(x)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class MixMlp(nn.Module):
    """Convolutional feed-forward network: 1x1 conv -> depth-wise 3x3 (CFF)
    -> activation -> dropout -> 1x1 conv -> dropout."""

    def __init__(self,
                 in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features if out_features else in_features
        hidden_features = hidden_features if hidden_features else in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)   # 1x1 expand
        self.dwconv = DWConv(hidden_features)                   # depth-wise mixing
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)  # 1x1 project
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Linear: truncated normal; Conv2d: He-style fan-out normal.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        out = self.fc1(x)
        out = self.act(self.dwconv(out))
        out = self.drop(out)
        out = self.drop(self.fc2(out))
        return out
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
"""
|
| 291 |
+
From TIMM repo: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
|
| 292 |
+
"""
|
| 293 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
|
| 294 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 295 |
+
|
| 296 |
+
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
| 297 |
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
| 298 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
| 299 |
+
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
| 300 |
+
'survival rate' as the argument.
|
| 301 |
+
|
| 302 |
+
"""
|
| 303 |
+
if drop_prob == 0. or not training:
|
| 304 |
+
return x
|
| 305 |
+
keep_prob = 1 - drop_prob
|
| 306 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
| 307 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
| 308 |
+
if keep_prob > 0.0 and scale_by_keep:
|
| 309 |
+
random_tensor.div_(keep_prob)
|
| 310 |
+
return x * random_tensor
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Module wrapper around :func:`drop_path`; active only while the module
    is in training mode.
    """
    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        # Probability of dropping a whole sample's residual branch.
        self.drop_prob = drop_prob
        # Rescale survivors by 1/keep_prob to preserve the expectation.
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        # Shown in the module's repr(), e.g. "drop_prob=0.100".
        return f'drop_prob={round(self.drop_prob,3):0.3f}'
|
| 326 |
+
|
| 327 |
+
def _trunc_normal_(tensor, mean, std, a, b):
|
| 328 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
| 329 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
| 330 |
+
def norm_cdf(x):
|
| 331 |
+
# Computes standard normal cumulative distribution function
|
| 332 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
| 333 |
+
|
| 334 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
| 335 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
| 336 |
+
"The distribution of values may be incorrect.",
|
| 337 |
+
stacklevel=2)
|
| 338 |
+
|
| 339 |
+
# Values are generated by using a truncated uniform distribution and
|
| 340 |
+
# then using the inverse CDF for the normal distribution.
|
| 341 |
+
# Get upper and lower cdf values
|
| 342 |
+
l = norm_cdf((a - mean) / std)
|
| 343 |
+
u = norm_cdf((b - mean) / std)
|
| 344 |
+
|
| 345 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
| 346 |
+
# [2l-1, 2u-1].
|
| 347 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
| 348 |
+
|
| 349 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
| 350 |
+
# standard normal
|
| 351 |
+
tensor.erfinv_()
|
| 352 |
+
|
| 353 |
+
# Transform to proper mean, std
|
| 354 |
+
tensor.mul_(std * math.sqrt(2.))
|
| 355 |
+
tensor.add_(mean)
|
| 356 |
+
|
| 357 |
+
# Clamp to ensure it's in the proper range
|
| 358 |
+
tensor.clamp_(min=a, max=b)
|
| 359 |
+
return tensor
|
| 360 |
+
|
| 361 |
+
class Encoder(nn.Module):
    """Spatial encoder: a stack of ConvSC layers with strides from
    stride_generator (alternating 1/2 downsampling)."""

    def __init__(self, C_in, C_hid, N_S):
        super(Encoder, self).__init__()
        strides = stride_generator(N_S)
        layers = [ConvSC(C_in, C_hid, stride=strides[0])]
        layers.extend(ConvSC(C_hid, C_hid, stride=s) for s in strides[1:])
        self.enc = nn.Sequential(*layers)

    def forward(self, x):  # x: (B*T, C, H, W)
        latent = x
        for layer in self.enc:
            latent = layer(latent)
        return latent
|
| 376 |
+
|
| 377 |
+
class Decoder(nn.Module):
    """Spatial decoder: transposed ConvSC stack mirroring the encoder,
    a 1x1 readout conv, and an optional sigmoid output activation."""

    def __init__(self, C_hid, C_out, N_S, last_activation='sigmoid'):
        super(Decoder, self).__init__()
        strides = stride_generator(N_S, reverse=True)
        self.dec = nn.Sequential(
            *[ConvSC(C_hid, C_hid, stride=s, transpose=True) for s in strides[:-1]],
            ConvSC(C_hid, C_hid, stride=strides[-1], transpose=True)
        )
        self.readout = nn.Conv2d(C_hid, C_out, 1)
        self.last = nn.Sigmoid() if last_activation == 'sigmoid' else nn.Identity()

    def forward(self, hid):
        out = hid
        for layer in self.dec:
            out = layer(out)
        return self.last(self.readout(out))
|
| 397 |
+
|
| 398 |
+
class BasicConv2d(nn.Module):
    """Conv2d or ConvTranspose2d, optionally followed by GroupNorm(2) and
    LeakyReLU(0.2) when ``act_norm`` is set."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, transpose=False, act_norm=False):
        super(BasicConv2d, self).__init__()
        self.act_norm = act_norm
        if transpose:
            # output_padding = stride // 2 makes stride-2 upsampling exactly
            # double the spatial size for kernel 3 / padding 1.
            self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
                                           stride=stride, padding=padding, output_padding=stride // 2)
        else:
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                                  stride=stride, padding=padding)
        self.norm = nn.GroupNorm(2, out_channels)
        self.act = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        y = self.conv(x)
        return self.act(self.norm(y)) if self.act_norm else y
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
class ConvSC(nn.Module):
    """3x3 BasicConv2d wrapper used by Encoder/Decoder; stride-1 layers
    never use the transposed convolution."""

    def __init__(self, C_in, C_out, stride, transpose=False, act_norm=True):
        super(ConvSC, self).__init__()
        use_transpose = False if stride == 1 else transpose
        self.conv = BasicConv2d(C_in, C_out, kernel_size=3, stride=stride,
                                padding=1, transpose=use_transpose, act_norm=act_norm)

    def forward(self, x):
        return self.conv(x)
|
| 427 |
+
|
| 428 |
+
def stride_generator(N, reverse=False):
    """Alternating [1, 2, 1, 2, ...] stride pattern of length N.

    Generalized: the original hard-coded ``[1, 2] * 10`` and therefore
    silently truncated the result for N > 20; this version works for any
    non-negative N while producing identical output for N <= 20.

    Args:
        N: number of strides to generate.
        reverse: if True, return the pattern reversed (decoder order).

    Returns:
        list[int] of length N.
    """
    strides = ([1, 2] * ((N + 1) // 2))[:N]
    return list(reversed(strides)) if reverse else strides
|
stldm/stldm.py
ADDED
|
@@ -0,0 +1,612 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch, random
|
| 2 |
+
from torch import nn
|
| 3 |
+
from einops import rearrange
|
| 4 |
+
|
| 5 |
+
from stldm.submodules import *
|
| 6 |
+
|
| 7 |
+
class Down_Block(nn.Module):
    """U-Net encoder stage: ResNet block (with conditioning concat), spatial
    attention, second ResNet block, temporal attention, then downsample
    (or a 1x1 channel conversion on the last stage)."""

    def __init__(self, in_ch, hid_ch, out_ch, time_dim, is_last, patch_size=None, num_groups=8, heads=4, dim_head=32):
        super(Down_Block, self).__init__()
        self.block1 = ResnetBlock(dim=in_ch, dim_out=hid_ch, time_emb_dim=time_dim, groups=num_groups)
        # Quadratic attention by default; patch-based linear attention when a patch_size is given.
        self.attn_spatial = Residual(PreNorm(hid_ch, Quadratic_SpatialAttention(dim=hid_ch, heads=heads, dim_head=dim_head))) if patch_size is None else Residual(PreNorm(hid_ch, Linear_SpatialAttention(dim=hid_ch, patch_size=patch_size, heads=heads, dim_head=dim_head)))
        self.block2 = ResnetBlock(dim=hid_ch, dim_out=hid_ch, groups=num_groups)
        # self.attn_temporal = Residual(PreNorm(hid_ch, TemporalAttention_Pos(dim=hid_ch, heads=heads, dim_head=dim_head)))
        self.attn_temporal = Residual(PreNorm(hid_ch, TemporalAttention(dim=hid_ch, heads=heads, dim_head=dim_head)))
        self.last = Downsample2D(dim_in=hid_ch, dim_out=out_ch) if not is_last else ChannelConversion(hid_ch, out_ch)

    def forward(self, x, time_emb, cond=None, relative_pos=None):
        """x: (B, T, C, H, W); time_emb: (B, time_dim); cond: optional
        per-frame conditioning of shape (B*T, C, H, W) — TODO confirm shape
        against caller. relative_pos is currently unused (kept for the
        commented-out TemporalAttention_Pos variant).

        Returns (downsampled output (B, T, c', h', w'),
                 spatial-attention skip, temporal-attention skip).
        """
        assert x.ndim==5
        B, T, C, H, W = x.shape

        # Fold time into the batch for the per-frame 2D modules.
        x = x.reshape(B*T, C, H, W)
        if cond is None:
            cond = torch.zeros_like(x) # -> Unconditioning

        # Broadcast the per-sample time embedding to every frame.
        time_emb = time_emb.unsqueeze(1) # From (B C) to (B 1 C)
        time_emb = time_emb.repeat(1, T, 1)
        time_emb = time_emb.reshape(B*T, -1)

        out = torch.cat((x, cond), dim=1) # BT, 2C, H, W
        out = self.block1(out, time_emb)

        spatial_attn = self.attn_spatial(out)
        out = self.block2(spatial_attn, time_emb)
        *_, c, h, w = out.shape
        # Unfold back to (B, T, ...) so attention can run along the time axis.
        out = out.reshape(B,T,c,h,w)

        # temporal_attn = self.attn_temporal(out, relative_pos)
        temporal_attn = self.attn_temporal(out)
        temporal_attn = temporal_attn.reshape(B*T,c,h,w)

        out = self.last(temporal_attn)
        *_, c, h, w = out.shape

        return out.reshape(B, T, c, h, w), spatial_attn, temporal_attn
|
| 45 |
+
|
| 46 |
+
class MidBlock(nn.Module):
    """U-Net bottleneck: ResNet -> spatial attention -> ResNet -> temporal
    attention -> ResNet, all at constant channel width ``in_ch``."""

    def __init__(self, in_ch, time_dim, num_groups=8, heads=4, dim_head=32):
        super(MidBlock, self).__init__()
        self.block1 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)
        self.qattn_spatial = Residual(PreNorm(in_ch, Quadratic_SpatialAttention(dim=in_ch, heads=heads, dim_head=dim_head)))
        self.block2 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)
        # self.qattn_time = Residual(PreNorm(in_ch, TemporalAttention_Pos(dim=in_ch, heads=heads, dim_head=dim_head)))
        self.qattn_time = Residual(PreNorm(in_ch, TemporalAttention(dim=in_ch, heads=heads, dim_head=dim_head)))
        self.block3 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)

    def forward(self, x, time_emb, relative_pos=None):
        """x: (B, T, C, H, W); time_emb: (B, time_dim). relative_pos is
        unused (kept for the commented-out positional variant).
        Returns a tensor of the same (B, T, C, H, W) shape."""
        assert x.ndim==5
        B, T, C, H, W = x.shape
        # Fold time into the batch for the per-frame 2D modules.
        x = x.reshape(B*T, C, H, W)

        # Broadcast the per-sample time embedding to every frame.
        time_emb = time_emb.unsqueeze(1) # From (B C) to (B 1 C)
        time_emb = time_emb.repeat(1, T, 1)
        time_emb = time_emb.reshape(B*T, -1)

        out = self.block1(x, time_emb)
        out = self.qattn_spatial(out)
        out = self.block2(out, time_emb) # a little bit difference here

        # Unfold to (B, T, ...) for attention along the time axis.
        out = out.reshape((B, T, C, H, W))
        # out = self.qattn_time(out, relative_pos).reshape(B*T, C, H, W)
        out = self.qattn_time(out).reshape(B*T, C, H, W)
        out = self.block3(out, time_emb)

        *_, c, h, w = out.shape
        return out.reshape(B, T, c, h, w)
|
| 76 |
+
|
| 77 |
+
class Up_Block(nn.Module):
    """U-Net decoder stage: upsample (or 1x1 conversion on the first stage),
    temporal attention, ResNet block over the temporal skip concat, spatial
    attention, then ResNet block over the spatial skip concat."""

    def __init__(self, in_chs, hid_ch, out_ch, is_last, time_dim, patch_size=None, num_groups=8, heads=4, dim_head=32):
        super(Up_Block, self).__init__()
        # in_chs = (channels coming from below, channels of the skip tensors).
        in_ch, skip_ch = in_chs
        self.up = Upsample2D(dim_in=in_ch, dim_out=hid_ch) if not is_last else ChannelConversion(in_ch, hid_ch)
        # Quadratic attention by default; patch-based linear attention when a patch_size is given.
        self.attn_spatial = Residual(PreNorm(hid_ch, Quadratic_SpatialAttention(dim=hid_ch, heads=heads, dim_head=dim_head) if patch_size is None else Linear_SpatialAttention(dim=hid_ch, patch_size=patch_size, heads=heads, dim_head=dim_head)))
        self.block1 = ResnetBlock(dim=hid_ch+skip_ch, dim_out=hid_ch, time_emb_dim=time_dim, groups=num_groups)
        # self.attn_temporal = Residual(PreNorm(hid_ch, TemporalAttention_Pos(dim=hid_ch, heads=heads, dim_head=dim_head)))
        self.attn_temporal = Residual(PreNorm(hid_ch, TemporalAttention(dim=hid_ch, heads=heads, dim_head=dim_head)))
        self.block2 = ResnetBlock(dim=hid_ch+skip_ch, dim_out=out_ch, time_emb_dim=time_dim, groups=num_groups)

    def forward(self, x, time_emb, spatialattn_skip, tempattn_skip, relative_pos=None):
        """x: (B, T, C, H, W); time_emb: (B, time_dim); the two skip tensors
        come from the matching Down_Block (B*T leading dim). relative_pos is
        unused (kept for the commented-out positional variant).
        Returns (B, T, out_ch, h', w')."""
        assert x.ndim==5
        B, T, C, H, W = x.shape
        # Fold time into the batch for the per-frame 2D modules.
        x = x.reshape(B*T, C, H, W)

        # Broadcast the per-sample time embedding to every frame.
        time_emb = time_emb.unsqueeze(1) # From (B C) to (B 1 C)
        time_emb = time_emb.repeat(1, T, 1)
        time_emb = time_emb.reshape(B*T, -1)

        out = self.up(x)
        *_, c, h, w = out.shape
        out = out.reshape(-1, T, c, h, w)

        # out = self.attn_temporal(out, relative_pos).reshape(B*T, c, h, w)
        out = self.attn_temporal(out).reshape(B*T, c, h, w)

        # Concatenate the temporal-attention skip before the first ResNet block.
        out = torch.cat((out, tempattn_skip), dim=1)
        out = self.block1(out, time_emb)

        out = self.attn_spatial(out)

        # Concatenate the spatial-attention skip before the second ResNet block.
        out = torch.cat((out, spatialattn_skip), dim=1)
        out = self.block2(out, time_emb)
        *_, c, h, w = out.shape
        return out.reshape(B, T, c, h, w)
|
| 113 |
+
|
| 114 |
+
class LDM(nn.Module):
    """Latent diffusion U-Net: time-embedding MLP, a stack of Down_Blocks
    with per-level condition downsamplers, a MidBlock bottleneck, and the
    mirrored stack of Up_Blocks wired through attention skips."""

    def __init__(self, in_ch, chs_mult:tuple, patch_size=None, num_groups=8, heads=4, dim_head=32, base_ch=64):
        super(LDM, self).__init__()
        # Time Embedding MLP
        time_dim = 4*base_ch
        fourier_dim = base_ch
        self.time_mlp = Time_MLP(dim=base_ch, time_dim=time_dim, fourier_dim=fourier_dim)

        ups, downs = [], []
        conditions = []

        layer_no = len(chs_mult)
        # Channel ladder: [in_ch, base_ch*m1, base_ch*m2, ...].
        chs = [in_ch, *map(lambda m: base_ch*m, chs_mult)]
        ch_in, ch_out = chs[:-1], chs[1:]
        up_in, up_out = list(reversed(ch_out)), list(reversed(ch_in))

        patches = None if patch_size is None else [patch_size//(2**n) for n in range(layer_no)] # Patch Size should be 2^N
        for n in range(layer_no):
            # Down path: input is (noisy latent, condition) concatenated -> 2*ch_in.
            downs.append(
                Down_Block(in_ch=2*ch_in[n], hid_ch=ch_in[n], out_ch=ch_out[n], time_dim=time_dim, patch_size=None if patch_size is None else patches[n], is_last=(n==layer_no-1), num_groups=num_groups, heads=heads, dim_head=dim_head)
            )
            ups.append(
                Up_Block(in_chs=(up_in[n], ch_in[-n-1]), hid_ch=up_in[n], out_ch=up_out[n], time_dim=time_dim, patch_size=None if patch_size is None else patches[layer_no-n-1], is_last=(n==0), num_groups=num_groups, heads=heads, dim_head=dim_head)
            )
            # NOTE(review): `n != -1` is always true, so a condition
            # downsampler is built for every level — possibly meant to skip
            # the last level; changing it would alter the state dict, so it
            # is left as-is. TODO confirm intent.
            if n != -1:
                conditions.append(
                    Downsample2D(dim_in=ch_in[n], dim_out=ch_out[n])
                )

        self.downs = nn.ModuleList(downs)
        self.ups = nn.ModuleList(ups)
        self.conditions = nn.ModuleList(conditions)
        self.mid = MidBlock(in_ch=ch_out[-1], time_dim=time_dim, num_groups=num_groups, heads=heads, dim_head=dim_head)
        # self.relative_pos = RelativePositionBias(heads=heads)

    def forward(self, x, time, conds=None):
        """x: noisy latent clip (B, T, C, H, W); time: diffusion timesteps
        (B,); conds: optional conditioning fed to every Down_Block and
        downsampled between levels. Returns the denoising prediction with
        the input's spatial/temporal layout."""
        t = self.time_mlp(time)

        # Attention skips collected on the way down, popped on the way up.
        hid_spatial = []
        hid_temporal = []

        # relative_position = self.relative_pos(x.shape[1], x.device) # Calculate The Relative Position

        for n, down_block in enumerate(self.downs):
            # print(x.shape)
            # x, spatial_attn, time_attn = down_block(x, t, conds, relative_position)
            x, spatial_attn, time_attn = down_block(x, t, conds)
            hid_spatial.append(spatial_attn)
            hid_temporal.append(time_attn)
            if conds is not None:
                # Match the condition's resolution/channels to the next level.
                conds = self.conditions[n](conds)

        # out = self.mid(x, t, relative_position)
        out = self.mid(x, t)

        for up_block in self.ups:
            # out = up_block(out, t, hid_spatial.pop(), hid_temporal.pop(), relative_position)
            out = up_block(out, t, hid_spatial.pop(), hid_temporal.pop())

        return out
|
| 174 |
+
|
| 175 |
+
# constants
|
| 176 |
+
from collections import namedtuple
|
| 177 |
+
from torch.cuda.amp import autocast
|
| 178 |
+
import torch.nn.functional as F
|
| 179 |
+
from einops import reduce
|
| 180 |
+
from tqdm.auto import tqdm
|
| 181 |
+
|
| 182 |
+
# (pred_noise, pred_x_start) pair returned by the diffusion model's prediction step.
ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
|
| 183 |
+
|
| 184 |
+
def identity(t, *args, **kwargs):
    """Return ``t`` unchanged, ignoring any extra arguments (used as a
    no-op stand-in for optional transforms)."""
    return t
|
| 186 |
+
|
| 187 |
+
def extract(a, t, x_shape):
    """Gather per-timestep coefficients ``a[t]`` and reshape them to
    broadcast against a tensor of shape ``x_shape``.

    Args:
        a: 1-D tensor of schedule coefficients.
        t: (batch,) tensor of integer timestep indices.
        x_shape: target shape; only its rank is used.

    Returns:
        Tensor of shape (batch, 1, 1, ..., 1) with rank len(x_shape).
    """
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    return gathered.reshape(batch, *((1,) * (len(x_shape) - 1)))
|
| 191 |
+
|
| 192 |
+
def default(val, d):
    """Return ``val`` when it is not None; otherwise the fallback ``d``
    (called first if it is callable, for lazy defaults)."""
    if val is not None:
        return val
    return d() if callable(d) else d
|
| 196 |
+
|
| 197 |
+
def exists(x):
    """True iff ``x`` is not None."""
    return x is not None
|
| 199 |
+
|
| 200 |
+
def guidance_scheduler(sampling_step: int, const: float):
    """Constant classifier-free-guidance scale: a length-``sampling_step``
    tensor filled with ``const``."""
    return torch.ones(sampling_step) * const
|
| 202 |
+
|
| 203 |
+
class GaussianDiffusion(nn.Module):
|
| 204 |
+
    def __init__(
        self,
        vp_model,
        diffusion,
        timesteps = 1000,
        sampling_timesteps = None,
        objective = 'pred_v',
        beta_schedule = 'sigmoid',
        schedule_fn_kwargs = dict(),
        ddim_sampling_eta = 0.,
        offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise
        min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
        min_snr_gamma = 5
    ):
        """Gaussian diffusion wrapper around a video-prediction backbone and
        a denoising U-Net.

        Args:
            vp_model: deterministic video-prediction backbone.
            diffusion: denoising network (the U-Net).
            timesteps: number of training diffusion steps.
            sampling_timesteps: steps at sampling time; defaults to
                ``timesteps`` (DDIM sampling is used when strictly smaller).
            objective: prediction target — 'pred_noise', 'pred_x0' or 'pred_v'.
            beta_schedule: 'linear', 'cosine' or 'sigmoid'.
            schedule_fn_kwargs: extra kwargs for the schedule function.
                NOTE(review): mutable default ``dict()`` — shared across
                instances if ever mutated; left unchanged here.
            ddim_sampling_eta: DDIM stochasticity (0 = deterministic).
            offset_noise_strength: offset-noise trick strength (see URL).
            min_snr_loss_weight / min_snr_gamma: min-SNR loss reweighting.
        """
        super(GaussianDiffusion, self).__init__()

        self.backbone = vp_model
        self.diff_unet = diffusion

        self.objective = objective
        assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'

        if beta_schedule == 'linear':
            beta_schedule_fn = linear_beta_schedule
        elif beta_schedule == 'cosine':
            beta_schedule_fn = cosine_beta_schedule
        elif beta_schedule == 'sigmoid':
            beta_schedule_fn = sigmoid_beta_schedule
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')

        betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)

        # Forward-process quantities: alpha_t, cumulative product, and the
        # cumulative product shifted by one step (padded with 1 at t=0).
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)

        # sampling related parameters

        self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training

        assert self.sampling_timesteps <= timesteps
        self.is_ddim_sampling = self.sampling_timesteps < timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

        # helper function to register buffer from float64 to float32

        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others

        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)

        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)

        register_buffer('posterior_variance', posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain

        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        # offset noise strength - in blogpost, they claimed 0.1 was ideal

        self.offset_noise_strength = offset_noise_strength

        # derive loss weight
        # snr - signal noise ratio

        snr = alphas_cumprod / (1 - alphas_cumprod)

        # https://arxiv.org/abs/2303.09556

        maybe_clipped_snr = snr.clone()
        if min_snr_loss_weight:
            maybe_clipped_snr.clamp_(max = min_snr_gamma)

        # Per-timestep loss weights for the chosen parameterization.
        if objective == 'pred_noise':
            register_buffer('loss_weight', maybe_clipped_snr / snr)
        elif objective == 'pred_x0':
            register_buffer('loss_weight', maybe_clipped_snr)
        elif objective == 'pred_v':
            register_buffer('loss_weight', maybe_clipped_snr / (snr + 1))
|
| 303 |
+
|
| 304 |
+
    @property
    def device(self):
        # Device is inferred from the 'betas' buffer, which moves with the module.
        return self.betas.device
|
| 307 |
+
|
| 308 |
+
# CFG schdeuler => by taking pre-setting scheduler
|
| 309 |
+
def setup_guidance(self, scheduler):
|
| 310 |
+
if exists(scheduler):
|
| 311 |
+
self.CFG_sch = scheduler.to(self.device)
|
| 312 |
+
else:
|
| 313 |
+
self.CFG_sch = scheduler
|
| 314 |
+
|
| 315 |
+
def predict_start_from_noise(self, x_t, t, noise):
|
| 316 |
+
return (
|
| 317 |
+
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
|
| 318 |
+
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
def predict_noise_from_start(self, x_t, t, x0):
|
| 322 |
+
return (
|
| 323 |
+
(extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
|
| 324 |
+
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
def predict_v(self, x_start, t, noise):
|
| 328 |
+
return (
|
| 329 |
+
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
|
| 330 |
+
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
def predict_start_from_v(self, x_t, t, v):
|
| 334 |
+
return (
|
| 335 |
+
extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
|
| 336 |
+
extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
def q_posterior(self, x_start, x_t, t):
|
| 340 |
+
posterior_mean = (
|
| 341 |
+
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
| 342 |
+
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
| 343 |
+
)
|
| 344 |
+
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
| 345 |
+
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
| 346 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
| 347 |
+
|
| 348 |
+
    def model_predictions(self, x, t, cond, clip_x_start = False, rederive_pred_noise = False):
        """Run the denoising U-Net and convert its raw output to (pred_noise, pred_x_start).

        When a CFG schedule is installed (setup_guidance), the model is evaluated
        twice (unconditional + conditional) and the outputs are combined with the
        per-timestep guidance weight. The conversion to noise / x_0 depends on
        self.objective ('pred_noise' | 'pred_x0' | 'pred_v').
        """
        # print(t.device)
        if exists(self.CFG_sch):
            # Unconditional pass: conds=None makes the U-Net substitute a zero condition.
            uncond = self.diff_unet(x, t, conds=None) #conds=torch.zeros_like(cond))
            model_output = self.diff_unet(x, t, conds=cond)
            # During sampling every batch element shares one timestep, so t[0] indexes the schedule.
            time = int(t[0])
            # Guided output == cond + w * (cond - uncond) with w = CFG_sch[time].
            model_output = model_output - self.CFG_sch[time] * (uncond - model_output)
        else:
            model_output = self.diff_unet(x, t, conds=cond)
        # Optional clamp of the reconstructed x_0 to the VAE latent range [-1, 1].
        maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity

        if self.objective == 'pred_noise':
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, pred_noise)
            x_start = maybe_clip(x_start)

            # After clipping, the original noise no longer matches x_start; recompute it.
            if clip_x_start and rederive_pred_noise:
                pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_x0':
            x_start = model_output
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_v':
            v = model_output
            x_start = self.predict_start_from_v(x, t, v)
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        return ModelPrediction(pred_noise, x_start)
|
| 379 |
+
|
| 380 |
+
    def p_mean_variance(self, x, t, cond=None, clip_denoised = True):
        """Mean/variance of p(x_{t-1} | x_t) derived from the model's x_0 estimate."""
        preds = self.model_predictions(x, t, cond=cond, clip_x_start=False,)
        x_start = preds.pred_x_start

        if clip_denoised:
            # In-place clamp of the predicted x_0 before forming the posterior.
            x_start.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
        return model_mean, posterior_variance, posterior_log_variance, x_start
|
| 389 |
+
|
| 390 |
+
    @torch.no_grad()
    def p_sample(self, x, t: int, cond=None):
        """One ancestral (DDPM) reverse step: sample x_{t-1} given x_t.

        Returns (pred_img, x_start) where x_start is the model's current x_0 estimate.
        """
        b, *_, device = *x.shape, self.device
        # Broadcast the scalar timestep to a per-batch tensor for the model.
        batched_times = torch.full((b,), t, device = device, dtype = torch.long)
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, cond=cond, clip_denoised = False)
        noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred_img, x_start
|
| 398 |
+
|
| 399 |
+
    @torch.no_grad()
    def p_sample_loop(self, shape, cond=None, return_all_timesteps = False):
        """Full DDPM sampling loop from pure noise down to t=0.

        Returns the final sample, or all intermediate samples stacked on dim 1
        when return_all_timesteps is True.
        """
        batch, device = shape[0], self.device

        frames_pred = torch.randn(shape, device = device)
        imgs = [frames_pred]

        for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps, disable=True):
            frames_pred, _ = self.p_sample(frames_pred, t, cond=cond)
            imgs.append(frames_pred)

        ret = frames_pred if not return_all_timesteps else torch.stack(imgs, dim = 1)
        return ret
|
| 412 |
+
|
| 413 |
+
    @torch.no_grad()
    def ddim_sample(self, shape, cond=None, return_all_timesteps = False):
        """DDIM sampling over a sub-sequence of timesteps (deterministic when eta == 0)."""
        batch, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
        device = self.device
        times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        frames_pred = torch.randn(shape, device = device)
        imgs = [frames_pred]

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', disable=True):
            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
            pred_noise, x_start, *_ = self.model_predictions(
                frames_pred,
                time_cond,
                cond = cond, #cond.copy(),
                clip_x_start = False,
                rederive_pred_noise = True
            )

            # Final step: the model's x_0 estimate is the output.
            if time_next < 0:
                frames_pred = x_start
                imgs.append(frames_pred)
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            # DDIM update (Song et al. 2020, eq. 12): sigma controls stochasticity.
            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()

            noise = torch.randn_like(frames_pred)

            frames_pred = x_start * alpha_next.sqrt() + \
                          c * pred_noise + \
                          sigma * noise

            imgs.append(frames_pred)

        ret = frames_pred if not return_all_timesteps else torch.stack(imgs, dim = 1)
        return ret
|
| 455 |
+
|
| 456 |
+
    @torch.no_grad()
    def sample(self, frames_in, return_all_timesteps = False):
        """Generate future frames: backbone conditioning -> latent diffusion -> VAE decode.

        Returns (frames_pred, backbone_output).
        NOTE(review): when return_all_timesteps=True the sampler output gains an
        extra timestep dimension, which the 'b t c h w' rearrange below does not
        account for — confirm this flag is only used with 5-D outputs.
        """
        assert frames_in.ndim == 5
        B, T_in, C, H, W = frames_in.shape
        device = self.device

        backbone_output, conds, *_ = self.backbone(frames_in)
        sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample

        # Target latent shape is inferred from the conditioning features.
        *_, c, h, w = conds.shape
        tgt_shape = conds.reshape(B, -1, c, h, w).shape
        ldm_pred = sample_fn(
            tgt_shape,
            cond=conds,
            return_all_timesteps = return_all_timesteps
        )

        # Decode latents frame-by-frame through the VAE, then restore the batch/time split.
        ldm_pred = rearrange(ldm_pred, 'b t c h w -> (b t) c h w')
        frames_pred = self.backbone.vae.decode(ldm_pred)
        frames_pred = rearrange(frames_pred, '(b t) c h w -> b t c h w', b=B)
        return frames_pred, backbone_output
|
| 477 |
+
|
| 478 |
+
def predict(self, frames_in, compute_loss=False, **kwargs):
|
| 479 |
+
pred, mu = self.sample(frames_in=frames_in)
|
| 480 |
+
return pred, mu
|
| 481 |
+
|
| 482 |
+
    def compute_loss(self, frames_in, frames_gt, validate=False):
        """Training losses: returns (recon_loss, mu_loss, diff_loss, prior_loss).

        NOTE(review): self.backbone(frames_in) is unpacked into exactly two values
        here, while sample() unpacks it with a trailing *_ — confirm the backbone
        returns exactly two values on this code path.
        """
        compute_loss = True and (not validate)
        B, T_in, C, H, W = frames_in.shape
        T_out = frames_gt.shape[1]
        device = frames_in.device

        """
        Diffusion Loss
        """
        backbone_output, conds = self.backbone(frames_in)
        # Encode ground-truth frames into the VAE latent space as diffusion targets.
        hid_gt, _ = self.backbone.vae.encode(
            rearrange(frames_gt, 'b t c h w -> (b t) c h w')
        )
        hid_gt = rearrange(hid_gt, '(b t) c h w -> b t c h w', b=B)
        t = torch.randint(0, self.num_timesteps, (B,), device=self.device).long()
        # Randomly drop the condition ~15% of the time to train for classifier-free guidance.
        if random.random() > 0.85: # Unconditional
            conds = None
        # detach(): the diffusion loss must not backprop into the VAE encoder.
        diff_loss = self.p_losses(hid_gt.detach(), t, cond=conds)

        """
        Backbone Loss
        """
        mu_loss = self.backbone._losses_(frames_in, frames_gt)

        """
        VAE Loss
        """
        # Autoencoding reconstruction + KL over the concatenated input and target frames.
        ae_loss, kl_loss = self.backbone.vae._losses_(
            rearrange(torch.cat((frames_in, frames_gt), dim=1), 'b t c h w -> (b t) c h w'),
            rearrange(torch.cat((frames_in, frames_gt), dim=1), 'b t c h w -> (b t) c h w')
        )
        kl_weight = 1E-6
        recon_loss = ae_loss + kl_weight*kl_loss

        """
        Prior Loss at t=T [Noisy]
        """
        # Push the fully-noised latent distribution at t=T towards the standard normal prior.
        hid_gt, _ = self.backbone.vae.encode(
            rearrange(frames_gt, 'b t c h w -> (b t) c h w')
        )
        hid_gt = rearrange(hid_gt, '(b t) c h w -> b t c h w', b=B)
        T = torch.ones((B,), device=self.device).long() * (self.num_timesteps - 1)
        mu_noisy = extract(self.sqrt_alphas_cumprod, T, hid_gt.shape) * hid_gt
        sigma_noisy = extract(self.sqrt_one_minus_alphas_cumprod, T, hid_gt.shape)
        log_var_noisy = 2*torch.log(sigma_noisy)
        prior_loss = self.kl_from_standard_normal(mu_noisy, log_var_noisy)

        return recon_loss, mu_loss, diff_loss, prior_loss
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
def kl_from_standard_normal(self, mean, log_var):
|
| 533 |
+
kl = 0.5 * (log_var.exp() + mean.square() - 1.0 - log_var)
|
| 534 |
+
return kl.mean()
|
| 535 |
+
|
| 536 |
+
@autocast(enabled = False)
|
| 537 |
+
def q_sample(self, x_start, t, noise = None):
|
| 538 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
| 539 |
+
|
| 540 |
+
return (
|
| 541 |
+
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
| 542 |
+
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
    def p_losses(self, x_start, t, noise=None, offset_noise_strength=None, cond=None):
        """Weighted diffusion training loss for a batch of latents x_start at timesteps t."""
        b, T, c, h, w = x_start.shape

        noise = default(noise, lambda: torch.randn_like(x_start))

        # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
        offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)

        if offset_noise_strength > 0.:
            # NOTE(review): x_start is 5-D here but the rearrange produces a 4-D
            # (b, t, 1, 1) tensor — confirm this broadcast when the strength is
            # ever set > 0 (default is 0, so this branch is normally dead).
            offset_noise = torch.randn(x_start.shape[:2], device = self.device)
            noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')

        # noise sample
        x = self.q_sample(x_start=x_start, t=t, noise=noise) # Use q_sample here for updating: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L763

        model_out = self.diff_unet(x, t, conds=cond)

        # Target depends on the training objective chosen at construction.
        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_start
        elif self.objective == 'pred_v':
            v = self.predict_v(x_start, t, noise)
            target = v
        else:
            raise ValueError(f'unknown objective {self.objective}')

        loss = F.mse_loss(model_out, target, reduction = 'none') # (B, T, C, H, W)
        loss = reduce(loss, 'b ... -> b', 'mean')

        # Per-timestep SNR-derived weighting (min-SNR when enabled).
        loss = loss * extract(self.loss_weight, t, loss.shape)
        return loss.mean()
|
| 577 |
+
|
| 578 |
+
@torch.no_grad()
|
| 579 |
+
def forward(self, input_x, include_mu=False, **kwargs):
|
| 580 |
+
pred, mu = self.predict(input_x, compute_loss=False)
|
| 581 |
+
if include_mu:
|
| 582 |
+
return pred, mu
|
| 583 |
+
else:
|
| 584 |
+
return pred
|
| 585 |
+
|
| 586 |
+
from stldm.modules import SimVPV2_Model, VAE
|
| 587 |
+
def model_setup(model_config, print_info=False, cfg_str=None):
    """Build the full model (SimVPv2 backbone + LDM U-Net + diffusion wrapper).

    cfg_str, when given, installs a constant classifier-free-guidance schedule.
    NOTE(review): GaussianDiffusion is called here with vp_model=/diffusion=
    keywords, which differs from the vp_param/stldm_param signature visible in
    stldm_hf.py — confirm this targets the non-HF GaussianDiffusion class.
    """
    if print_info:
        print('Setup the model with considering temporal attention be (BHW, T, C) ... ...')
        print('Train it from end to end')
    vp_config = model_config['vp_param']
    ldm_config = model_config['stldm_param']

    vpm = SimVPV2_Model(**vp_config)
    ldm = LDM(**ldm_config)
    model = GaussianDiffusion(vp_model=vpm, diffusion=ldm, **model_config['param'])

    # Constant CFG weight over all sampling steps, or no guidance at all.
    scheduler = guidance_scheduler(sampling_step=model_config['param']['timesteps'], const=cfg_str) if cfg_str is not None else None
    model.setup_guidance(scheduler)

    return model
|
| 602 |
+
|
| 603 |
+
def ae_setup(model_config):
    """Build only the VAE autoencoder from the video-prediction config."""
    backbone = SimVPV2_Model(**model_config['vp_param'])
    return backbone.vae
|
| 608 |
+
|
| 609 |
+
def backbone_setup(model_config):
    """Build only the SimVPv2 backbone from the config."""
    return SimVPV2_Model(**model_config['vp_param'])
|
stldm/stldm_hf.py
ADDED
|
@@ -0,0 +1,620 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch, random
|
| 2 |
+
from torch import nn
|
| 3 |
+
from einops import rearrange
|
| 4 |
+
|
| 5 |
+
from stldm.submodules import *
|
| 6 |
+
|
| 7 |
+
class Down_Block(nn.Module):
    """Encoder stage of the video diffusion U-Net.

    Pipeline: concat(x, cond) -> ResNet -> spatial attention -> ResNet ->
    temporal attention -> downsample (or channel conversion on the last stage).
    Returns the downsampled features plus both attention maps as skip tensors.
    """
    def __init__(self, in_ch, hid_ch, out_ch, time_dim, is_last, patch_size=None, num_groups=8, heads=4, dim_head=32):
        super(Down_Block, self).__init__()
        # in_ch already accounts for the channel-wise concatenation with the condition.
        self.block1 = ResnetBlock(dim=in_ch, dim_out=hid_ch, time_emb_dim=time_dim, groups=num_groups)
        # Full (quadratic) spatial attention unless a patch size selects linear patch attention.
        self.attn_spatial = Residual(PreNorm(hid_ch, Quadratic_SpatialAttention(dim=hid_ch, heads=heads, dim_head=dim_head))) if patch_size is None else Residual(PreNorm(hid_ch, Linear_SpatialAttention(dim=hid_ch, patch_size=patch_size, heads=heads, dim_head=dim_head)))
        self.block2 = ResnetBlock(dim=hid_ch, dim_out=hid_ch, groups=num_groups)
        # self.attn_temporal = Residual(PreNorm(hid_ch, TemporalAttention_Pos(dim=hid_ch, heads=heads, dim_head=dim_head)))
        self.attn_temporal = Residual(PreNorm(hid_ch, TemporalAttention(dim=hid_ch, heads=heads, dim_head=dim_head)))
        # The deepest stage keeps spatial resolution and only converts channels.
        self.last = Downsample2D(dim_in=hid_ch, dim_out=out_ch) if not is_last else ChannelConversion(hid_ch, out_ch)

    def forward(self, x, time_emb, cond=None, relative_pos=None):
        """x: (B, T, C, H, W); cond: (B*T, C, H, W) or None (zeros => unconditional)."""
        assert x.ndim==5
        B, T, C, H, W = x.shape

        # Fold time into batch so 2-D modules see (B*T, C, H, W).
        x = x.reshape(B*T, C, H, W)
        if cond is None:
            cond = torch.zeros_like(x) # -> Unconditioning

        # Repeat the per-sample time embedding once per frame.
        time_emb = time_emb.unsqueeze(1) # From (B C) to (B 1 C)
        time_emb = time_emb.repeat(1, T, 1)
        time_emb = time_emb.reshape(B*T, -1)

        out = torch.cat((x, cond), dim=1) # BT, 2C, H, W
        out = self.block1(out, time_emb)

        spatial_attn = self.attn_spatial(out)
        out = self.block2(spatial_attn, time_emb)
        *_, c, h, w = out.shape
        # Restore the time axis for temporal attention.
        out = out.reshape(B,T,c,h,w)

        # temporal_attn = self.attn_temporal(out, relative_pos)
        temporal_attn = self.attn_temporal(out)
        temporal_attn = temporal_attn.reshape(B*T,c,h,w)

        out = self.last(temporal_attn)
        *_, c, h, w = out.shape

        # Attention outputs are returned flat (B*T, ...) as U-Net skip connections.
        return out.reshape(B, T, c, h, w), spatial_attn, temporal_attn
|
| 45 |
+
|
| 46 |
+
class MidBlock(nn.Module):
    """Bottleneck of the U-Net: ResNet -> spatial attn -> ResNet -> temporal attn -> ResNet.

    Keeps the channel count and spatial resolution unchanged.
    """
    def __init__(self, in_ch, time_dim, num_groups=8, heads=4, dim_head=32):
        super(MidBlock, self).__init__()
        self.block1 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)
        self.qattn_spatial = Residual(PreNorm(in_ch, Quadratic_SpatialAttention(dim=in_ch, heads=heads, dim_head=dim_head)))
        self.block2 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)
        # self.qattn_time = Residual(PreNorm(in_ch, TemporalAttention_Pos(dim=in_ch, heads=heads, dim_head=dim_head)))
        self.qattn_time = Residual(PreNorm(in_ch, TemporalAttention(dim=in_ch, heads=heads, dim_head=dim_head)))
        self.block3 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)

    def forward(self, x, time_emb, relative_pos=None):
        """x: (B, T, C, H, W); time_emb: (B, time_dim). Returns the same shape as x."""
        assert x.ndim==5
        B, T, C, H, W = x.shape
        # Fold time into batch for the 2-D ResNet / spatial-attention modules.
        x = x.reshape(B*T, C, H, W)

        # Repeat the per-sample time embedding once per frame.
        time_emb = time_emb.unsqueeze(1) # From (B C) to (B 1 C)
        time_emb = time_emb.repeat(1, T, 1)
        time_emb = time_emb.reshape(B*T, -1)

        out = self.block1(x, time_emb)
        out = self.qattn_spatial(out)
        out = self.block2(out, time_emb) # a little bit difference here

        # Restore the time axis for temporal attention, then flatten again.
        out = out.reshape((B, T, C, H, W))
        # out = self.qattn_time(out, relative_pos).reshape(B*T, C, H, W)
        out = self.qattn_time(out).reshape(B*T, C, H, W)
        out = self.block3(out, time_emb)

        *_, c, h, w = out.shape
        return out.reshape(B, T, c, h, w)
|
| 76 |
+
|
| 77 |
+
class Up_Block(nn.Module):
    """Decoder stage of the U-Net, mirroring Down_Block.

    Pipeline: upsample -> temporal attn -> concat temporal skip -> ResNet ->
    spatial attn -> concat spatial skip -> ResNet.
    """
    def __init__(self, in_chs, hid_ch, out_ch, is_last, time_dim, patch_size=None, num_groups=8, heads=4, dim_head=32):
        super(Up_Block, self).__init__()
        # in_chs = (channels coming up the decoder, channels of the encoder skip).
        in_ch, skip_ch = in_chs
        # The topmost stage keeps resolution and only converts channels.
        self.up = Upsample2D(dim_in=in_ch, dim_out=hid_ch) if not is_last else ChannelConversion(in_ch, hid_ch)
        self.attn_spatial = Residual(PreNorm(hid_ch, Quadratic_SpatialAttention(dim=hid_ch, heads=heads, dim_head=dim_head) if patch_size is None else Linear_SpatialAttention(dim=hid_ch, patch_size=patch_size, heads=heads, dim_head=dim_head)))
        self.block1 = ResnetBlock(dim=hid_ch+skip_ch, dim_out=hid_ch, time_emb_dim=time_dim, groups=num_groups)
        # self.attn_temporal = Residual(PreNorm(hid_ch, TemporalAttention_Pos(dim=hid_ch, heads=heads, dim_head=dim_head)))
        self.attn_temporal = Residual(PreNorm(hid_ch, TemporalAttention(dim=hid_ch, heads=heads, dim_head=dim_head)))
        self.block2 = ResnetBlock(dim=hid_ch+skip_ch, dim_out=out_ch, time_emb_dim=time_dim, groups=num_groups)

    def forward(self, x, time_emb, spatialattn_skip, tempattn_skip, relative_pos=None):
        """x: (B, T, C, H, W); skips are the flat (B*T, ...) tensors from Down_Block."""
        assert x.ndim==5
        B, T, C, H, W = x.shape
        x = x.reshape(B*T, C, H, W)

        # Repeat the per-sample time embedding once per frame.
        time_emb = time_emb.unsqueeze(1) # From (B C) to (B 1 C)
        time_emb = time_emb.repeat(1, T, 1)
        time_emb = time_emb.reshape(B*T, -1)

        out = self.up(x)
        *_, c, h, w = out.shape
        out = out.reshape(-1, T, c, h, w)

        # out = self.attn_temporal(out, relative_pos).reshape(B*T, c, h, w)
        out = self.attn_temporal(out).reshape(B*T, c, h, w)

        # Skip connections are concatenated channel-wise before each ResNet block.
        out = torch.cat((out, tempattn_skip), dim=1)
        out = self.block1(out, time_emb)

        out = self.attn_spatial(out)

        out = torch.cat((out, spatialattn_skip), dim=1)
        out = self.block2(out, time_emb)
        *_, c, h, w = out.shape
        return out.reshape(B, T, c, h, w)
|
| 113 |
+
|
| 114 |
+
class LDM(nn.Module):
    """Latent diffusion U-Net over video latents with spatial + temporal attention.

    Conditioning features are injected at every encoder stage and downsampled in
    lockstep with the main path via self.conditions.
    """
    def __init__(self, in_ch, chs_mult:tuple, patch_size=None, num_groups=8, heads=4, dim_head=32, base_ch=64):
        super(LDM, self).__init__()
        # Time Embedding MLP
        time_dim = 4*base_ch
        fourier_dim = base_ch
        self.time_mlp = Time_MLP(dim=base_ch, time_dim=time_dim, fourier_dim=fourier_dim)

        ups, downs = [], []
        conditions = []

        layer_no = len(chs_mult)
        # Channel progression: [in_ch, base_ch*m1, base_ch*m2, ...].
        chs = [in_ch, *map(lambda m: base_ch*m, chs_mult)]
        ch_in, ch_out = chs[:-1], chs[1:]
        # The decoder mirrors the encoder channel progression in reverse.
        up_in, up_out = list(reversed(ch_out)), list(reversed(ch_in))

        patches = None if patch_size is None else [patch_size//(2**n) for n in range(layer_no)] # Patch Size should be 2^N
        for n in range(layer_no):
            # Down_Block input is doubled: latent concatenated with the condition.
            downs.append(
                Down_Block(in_ch=2*ch_in[n], hid_ch=ch_in[n], out_ch=ch_out[n], time_dim=time_dim, patch_size=None if patch_size is None else patches[n], is_last=(n==layer_no-1), num_groups=num_groups, heads=heads, dim_head=dim_head)
            )
            ups.append(
                Up_Block(in_chs=(up_in[n], ch_in[-n-1]), hid_ch=up_in[n], out_ch=up_out[n], time_dim=time_dim, patch_size=None if patch_size is None else patches[layer_no-n-1], is_last=(n==0), num_groups=num_groups, heads=heads, dim_head=dim_head)
            )
            # NOTE(review): `n != -1` is always true, so a condition downsampler is
            # built for every stage — confirm whether the last stage was meant to
            # be excluded here.
            if n != -1:
                conditions.append(
                    Downsample2D(dim_in=ch_in[n], dim_out=ch_out[n])
                )

        self.downs = nn.ModuleList(downs)
        self.ups = nn.ModuleList(ups)
        self.conditions = nn.ModuleList(conditions)
        self.mid = MidBlock(in_ch=ch_out[-1], time_dim=time_dim, num_groups=num_groups, heads=heads, dim_head=dim_head)
        # self.relative_pos = RelativePositionBias(heads=heads)

    def forward(self, x, time, conds=None):
        """Denoise latents x (B, T, C, H, W) at timesteps `time`, optionally conditioned."""
        t = self.time_mlp(time)

        # Per-stage attention outputs, consumed LIFO as decoder skip connections.
        hid_spatial = []
        hid_temporal = []

        # relative_position = self.relative_pos(x.shape[1], x.device) # Calculate The Relative Position

        for n, down_block in enumerate(self.downs):
            # print(x.shape)
            # x, spatial_attn, time_attn = down_block(x, t, conds, relative_position)
            x, spatial_attn, time_attn = down_block(x, t, conds)
            hid_spatial.append(spatial_attn)
            hid_temporal.append(time_attn)
            # Downsample the condition to match the next stage's resolution.
            if conds is not None:
                conds = self.conditions[n](conds)

        # out = self.mid(x, t, relative_position)
        out = self.mid(x, t)

        for up_block in self.ups:
            # out = up_block(out, t, hid_spatial.pop(), hid_temporal.pop(), relative_position)
            out = up_block(out, t, hid_spatial.pop(), hid_temporal.pop())

        return out
|
| 174 |
+
|
| 175 |
+
# constants
|
| 176 |
+
from collections import namedtuple
|
| 177 |
+
from torch.cuda.amp import autocast
|
| 178 |
+
import torch.nn.functional as F
|
| 179 |
+
from einops import reduce
|
| 180 |
+
from tqdm.auto import tqdm
|
| 181 |
+
|
| 182 |
+
ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
|
| 183 |
+
|
| 184 |
+
def identity(t, *args, **kwargs):
    """No-op passthrough; extra arguments are accepted and ignored."""
    return t
|
| 186 |
+
|
| 187 |
+
def extract(a, t, x_shape):
    """Gather per-batch entries of `a` at indices `t`, reshaped to broadcast over x_shape."""
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    trailing_ones = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *trailing_ones)
|
| 191 |
+
|
| 192 |
+
def default(val, d):
    """Return val when it is not None, otherwise d (called first if it is callable)."""
    if val is not None:
        return val
    return d() if callable(d) else d
|
| 196 |
+
|
| 197 |
+
def exists(x):
    """True unless x is None."""
    return not (x is None)
|
| 199 |
+
|
| 200 |
+
def guidance_scheduler(sampling_step: int, const: float):
    """Constant classifier-free-guidance weight schedule of length sampling_step."""
    schedule = torch.ones(sampling_step)
    return schedule * const
|
| 202 |
+
|
| 203 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 204 |
+
|
| 205 |
+
class GaussianDiffusion(
|
| 206 |
+
nn.Module,
|
| 207 |
+
PyTorchModelHubMixin,
|
| 208 |
+
# optionally, you can add metadata which gets pushed to the model card
|
| 209 |
+
repo_url="https://github.com/sqfoo/stldm_official",
|
| 210 |
+
pipeline_tag="Precipitation_Nowcasting",
|
| 211 |
+
license="mit"):
|
| 212 |
+
    def __init__(
        self,
        vp_param: dict,           # kwargs for the SimVPV2_Model backbone
        stldm_param: dict,        # kwargs for the LDM denoising U-Net
        timesteps = 1000,
        sampling_timesteps = None,
        objective = 'pred_v',
        beta_schedule = 'sigmoid',
        schedule_fn_kwargs = dict(),
        ddim_sampling_eta = 0.,
        offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise
        min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
        min_snr_gamma = 5
    ):
        """Build the backbone + diffusion U-Net and precompute all DDPM schedule buffers."""
        super(GaussianDiffusion, self).__init__()

        self.backbone = SimVPV2_Model(**vp_param)
        self.diff_unet = LDM(**stldm_param)

        self.objective = objective
        assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'

        if beta_schedule == 'linear':
            beta_schedule_fn = linear_beta_schedule
        elif beta_schedule == 'cosine':
            beta_schedule_fn = cosine_beta_schedule
        elif beta_schedule == 'sigmoid':
            beta_schedule_fn = sigmoid_beta_schedule
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')

        betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # alpha_bar_{t-1}, padded with 1 at t=0.
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)

        # sampling related parameters

        self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training

        assert self.sampling_timesteps <= timesteps
        # Fewer sampling steps than training steps => use DDIM.
        self.is_ddim_sampling = self.sampling_timesteps < timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

        # helper function to register buffer from float64 to float32

        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others

        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)

        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)

        register_buffer('posterior_variance', posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain

        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        # offset noise strength - in blogpost, they claimed 0.1 was ideal

        self.offset_noise_strength = offset_noise_strength

        # derive loss weight
        # snr - signal noise ratio

        snr = alphas_cumprod / (1 - alphas_cumprod)

        # https://arxiv.org/abs/2303.09556

        maybe_clipped_snr = snr.clone()
        if min_snr_loss_weight:
            maybe_clipped_snr.clamp_(max = min_snr_gamma)

        # Objective validity was asserted above, so no else branch is needed here.
        if objective == 'pred_noise':
            register_buffer('loss_weight', maybe_clipped_snr / snr)
        elif objective == 'pred_x0':
            register_buffer('loss_weight', maybe_clipped_snr)
        elif objective == 'pred_v':
            register_buffer('loss_weight', maybe_clipped_snr / (snr + 1))
|
| 311 |
+
|
| 312 |
+
@property
def device(self):
    """Device on which the registered diffusion buffers (and thus the model) live."""
    return self.betas.device
|
| 315 |
+
|
| 316 |
+
# Classifier-free guidance: attach a pre-built per-timestep weight scheduler.
def setup_guidance(self, scheduler):
    """Store a CFG weight schedule (1-D tensor indexed by timestep), or None to disable guidance."""
    self.CFG_sch = scheduler.to(self.device) if exists(scheduler) else scheduler
|
| 322 |
+
|
| 323 |
+
def predict_start_from_noise(self, x_t, t, noise):
    """Recover the clean latent x_0 from noisy x_t and the predicted noise."""
    recip = extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape)
    recipm1 = extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    return recip * x_t - recipm1 * noise
|
| 328 |
+
|
| 329 |
+
def predict_noise_from_start(self, x_t, t, x0):
    """Inverse of predict_start_from_noise: derive the implied noise from x_t and x_0."""
    recip = extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape)
    recipm1 = extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    return (recip * x_t - x0) / recipm1
|
| 334 |
+
|
| 335 |
+
def predict_v(self, x_start, t, noise):
    """v-parameterization target (Progressive Distillation, appendix D)."""
    sqrt_ac = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_1mac = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return sqrt_ac * noise - sqrt_1mac * x_start
|
| 340 |
+
|
| 341 |
+
def predict_start_from_v(self, x_t, t, v):
    """Recover the clean latent x_0 from x_t and a predicted v."""
    sqrt_ac = extract(self.sqrt_alphas_cumprod, t, x_t.shape)
    sqrt_1mac = extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
    return sqrt_ac * x_t - sqrt_1mac * v
|
| 346 |
+
|
| 347 |
+
def q_posterior(self, x_start, x_t, t):
    """Mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
    mean = (extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t)
    variance = extract(self.posterior_variance, t, x_t.shape)
    log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
    return mean, variance, log_variance_clipped
|
| 355 |
+
|
| 356 |
+
def model_predictions(self, x, t, cond, clip_x_start = False, rederive_pred_noise = False):
    """Run the denoising U-Net (with optional classifier-free guidance) and
    normalize its output to a (pred_noise, pred_x_start) pair regardless of
    the training objective.

    Args:
        x: noisy latent batch.
        t: per-sample timestep tensor.
        cond: conditioning tensor for the U-Net, or None for unconditional.
        clip_x_start: clamp the x_0 estimate to [-1, 1].
        rederive_pred_noise: after clipping x_0, recompute a consistent noise.
    """
    if exists(self.CFG_sch):
        # Classifier-free guidance: blend conditional and unconditional passes.
        uncond = self.diff_unet(x, t, conds=None)
        model_output = self.diff_unet(x, t, conds=cond)
        # NOTE(review): assumes all entries of `t` share one timestep — true for
        # the sampling loops in this file, which build `t` via torch.full.
        time = int(t[0])
        model_output = model_output - self.CFG_sch[time] * (uncond - model_output)
    else:
        model_output = self.diff_unet(x, t, conds=cond)
    maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity

    if self.objective == 'pred_noise':
        pred_noise = model_output
        x_start = self.predict_start_from_noise(x, t, pred_noise)
        x_start = maybe_clip(x_start)

        if clip_x_start and rederive_pred_noise:
            # keep the noise estimate consistent with the clipped x_start
            pred_noise = self.predict_noise_from_start(x, t, x_start)

    elif self.objective == 'pred_x0':
        x_start = model_output
        x_start = maybe_clip(x_start)
        pred_noise = self.predict_noise_from_start(x, t, x_start)

    elif self.objective == 'pred_v':
        v = model_output
        x_start = self.predict_start_from_v(x, t, v)
        x_start = maybe_clip(x_start)
        pred_noise = self.predict_noise_from_start(x, t, x_start)

    else:
        # Fix: previously an unknown objective fell through and raised a
        # confusing NameError on `pred_noise`; mirror the guard in p_losses.
        raise ValueError(f'unknown objective {self.objective}')

    return ModelPrediction(pred_noise, x_start)
|
| 387 |
+
|
| 388 |
+
def p_mean_variance(self, x, t, cond=None, clip_denoised = True):
    """Posterior mean/variance of x_{t-1} given x_t, using the model's x_0 estimate."""
    preds = self.model_predictions(x, t, cond=cond, clip_x_start=False,)
    x_start = preds.pred_x_start

    if clip_denoised:
        x_start.clamp_(-1., 1.)

    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
        x_start = x_start, x_t = x, t = t)
    return model_mean, posterior_variance, posterior_log_variance, x_start
|
| 397 |
+
|
| 398 |
+
@torch.no_grad()
def p_sample(self, x, t: int, cond=None):
    """One ancestral reverse step x_t -> x_{t-1}; injects no noise at t == 0."""
    b = x.shape[0]
    batched_times = torch.full((b,), t, device = self.device, dtype = torch.long)
    model_mean, _, model_log_variance, x_start = self.p_mean_variance(
        x = x, t = batched_times, cond=cond, clip_denoised = False)
    noise = torch.randn_like(x) if t > 0 else 0.  # no noise on the final step
    pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
    return pred_img, x_start
|
| 406 |
+
|
| 407 |
+
@torch.no_grad()
def p_sample_loop(self, shape, cond=None, return_all_timesteps = False):
    """Full DDPM ancestral sampling from pure noise down to t = 0."""
    device = self.device

    frames_pred = torch.randn(shape, device = device)
    imgs = [frames_pred]

    steps = reversed(range(0, self.num_timesteps))
    for t in tqdm(steps, desc = 'sampling loop time step', total = self.num_timesteps, disable=True):
        frames_pred, _ = self.p_sample(frames_pred, t, cond=cond)
        imgs.append(frames_pred)

    if return_all_timesteps:
        return torch.stack(imgs, dim = 1)
    return frames_pred
|
| 420 |
+
|
| 421 |
+
@torch.no_grad()
def ddim_sample(self, shape, cond=None, return_all_timesteps = False):
    """DDIM sampling over `self.sampling_timesteps` steps (eta controls stochasticity)."""
    batch = shape[0]
    total_timesteps = self.num_timesteps
    sampling_timesteps = self.sampling_timesteps
    eta = self.ddim_sampling_eta
    device = self.device

    # [-1, 0, 1, ..., T-1] reversed into pairs [(T-1, T-2), ..., (1, 0), (0, -1)]
    times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1)
    times = list(reversed(times.int().tolist()))
    time_pairs = list(zip(times[:-1], times[1:]))

    frames_pred = torch.randn(shape, device = device)
    imgs = [frames_pred]

    for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', disable=True):
        time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
        pred_noise, x_start, *_ = self.model_predictions(
            frames_pred,
            time_cond,
            cond = cond,
            clip_x_start = False,
            rederive_pred_noise = True
        )

        if time_next < 0:
            # last transition: the x_0 estimate is the sample itself
            frames_pred = x_start
            imgs.append(frames_pred)
            continue

        alpha = self.alphas_cumprod[time]
        alpha_next = self.alphas_cumprod[time_next]

        sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
        c = (1 - alpha_next - sigma ** 2).sqrt()

        noise = torch.randn_like(frames_pred)
        frames_pred = x_start * alpha_next.sqrt() + c * pred_noise + sigma * noise
        imgs.append(frames_pred)

    if return_all_timesteps:
        return torch.stack(imgs, dim = 1)
    return frames_pred
|
| 463 |
+
|
| 464 |
+
@torch.no_grad()
def sample(self, frames_in, return_all_timesteps = False):
    """Predict future frames: condition on the backbone's features, sample latents,
    then decode them with the backbone's VAE.

    Args:
        frames_in: (B, T_in, C, H, W) context frames.
    Returns:
        (decoded predicted frames, backbone output).
    """
    assert frames_in.ndim == 5
    B = frames_in.shape[0]

    backbone_output, conds, *_ = self.backbone(frames_in)
    sample_fn = self.ddim_sample if self.is_ddim_sampling else self.p_sample_loop

    # target latent shape: (B, T_lat, c, h, w) inferred from the conditioning map
    *_, c, h, w = conds.shape
    tgt_shape = conds.reshape(B, -1, c, h, w).shape
    ldm_pred = sample_fn(
        tgt_shape,
        cond=conds,
        return_all_timesteps = return_all_timesteps
    )

    ldm_pred = rearrange(ldm_pred, 'b t c h w -> (b t) c h w')
    frames_pred = self.backbone.vae.decode(ldm_pred)
    frames_pred = rearrange(frames_pred, '(b t) c h w -> b t c h w', b=B)
    return frames_pred, backbone_output
|
| 485 |
+
|
| 486 |
+
def predict(self, frames_in, compute_loss=False, **kwargs):
    """Inference wrapper around sample(); `compute_loss` and extra kwargs are ignored."""
    pred, mu = self.sample(frames_in=frames_in)
    return pred, mu
|
| 489 |
+
|
| 490 |
+
def compute_loss(self, frames_in, frames_gt, validate=False):
    """Compute all training losses for one batch.

    Args:
        frames_in: (B, T_in, C, H, W) context frames.
        frames_gt: (B, T_out, C, H, W) ground-truth future frames.
        validate: accepted for interface compatibility; not used below.
    Returns:
        (recon_loss, mu_loss, diff_loss, prior_loss).
    """
    B = frames_in.shape[0]

    """
    Diffusion Loss
    """
    backbone_output, conds = self.backbone(frames_in)
    hid_gt, _ = self.backbone.vae.encode(
        rearrange(frames_gt, 'b t c h w -> (b t) c h w')
    )
    hid_gt = rearrange(hid_gt, '(b t) c h w -> b t c h w', b=B)
    t = torch.randint(0, self.num_timesteps, (B,), device=self.device).long()
    if random.random() > 0.85:  # drop conditioning ~15% of the time (classifier-free guidance)
        conds = None
    diff_loss = self.p_losses(hid_gt.detach(), t, cond=conds)

    """
    Backbone Loss
    """
    mu_loss = self.backbone._losses_(frames_in, frames_gt)

    """
    VAE Loss
    """
    # hoisted: the AE target and input are the same deterministic rearrangement
    all_frames = rearrange(torch.cat((frames_in, frames_gt), dim=1), 'b t c h w -> (b t) c h w')
    ae_loss, kl_loss = self.backbone.vae._losses_(all_frames, all_frames)
    kl_weight = 1E-6
    recon_loss = ae_loss + kl_weight*kl_loss

    """
    Prior Loss at t=T [Noisy]
    """
    # NOTE(review): re-encodes frames_gt exactly as above; kept in case vae.encode
    # is stochastic — confirm determinism before deduplicating.
    hid_gt, _ = self.backbone.vae.encode(
        rearrange(frames_gt, 'b t c h w -> (b t) c h w')
    )
    hid_gt = rearrange(hid_gt, '(b t) c h w -> b t c h w', b=B)
    T = torch.ones((B,), device=self.device).long() * (self.num_timesteps - 1)
    mu_noisy = extract(self.sqrt_alphas_cumprod, T, hid_gt.shape) * hid_gt
    sigma_noisy = extract(self.sqrt_one_minus_alphas_cumprod, T, hid_gt.shape)
    log_var_noisy = 2*torch.log(sigma_noisy)
    prior_loss = self.kl_from_standard_normal(mu_noisy, log_var_noisy)

    return recon_loss, mu_loss, diff_loss, prior_loss
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
def kl_from_standard_normal(self, mean, log_var):
    """Mean KL divergence of N(mean, exp(log_var)) from the standard normal N(0, I)."""
    var = log_var.exp()
    kl = 0.5 * (var + mean.square() - 1.0 - log_var)
    return kl.mean()
|
| 543 |
+
|
| 544 |
+
@autocast(enabled = False)
def q_sample(self, x_start, t, noise = None):
    """Forward diffusion: draw x_t ~ q(x_t | x_0), forced to full precision."""
    if noise is None:
        noise = torch.randn_like(x_start)

    return (extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
|
| 552 |
+
|
| 553 |
+
def p_losses(self, x_start, t, noise=None, offset_noise_strength=None, cond=None):
    """Single diffusion training loss on latent video x_start of shape (B, T, C, H, W).

    Args:
        x_start: clean latents.
        t: per-sample timesteps.
        noise: optional pre-drawn noise (defaults to standard normal).
        offset_noise_strength: overrides self.offset_noise_strength when given.
        cond: conditioning tensor, or None for unconditional training.
    """
    b, T, c, h, w = x_start.shape

    noise = default(noise, lambda: torch.randn_like(x_start))

    # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
    offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)

    if offset_noise_strength > 0.:
        # One offset sample per (batch, frame), broadcast over C, H, W.
        # Fix: latents are 5-D, so three trailing singleton axes are needed;
        # the previous 'b c -> b c 1 1' pattern produced a 4-D tensor that
        # cannot broadcast against (b, t, c, h, w).
        offset_noise = torch.randn(x_start.shape[:2], device = self.device)
        noise += offset_noise_strength * rearrange(offset_noise, 'b t -> b t 1 1 1')

    # noise sample
    x = self.q_sample(x_start=x_start, t=t, noise=noise)  # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L763

    model_out = self.diff_unet(x, t, conds=cond)

    if self.objective == 'pred_noise':
        target = noise
    elif self.objective == 'pred_x0':
        target = x_start
    elif self.objective == 'pred_v':
        target = self.predict_v(x_start, t, noise)
    else:
        raise ValueError(f'unknown objective {self.objective}')

    loss = F.mse_loss(model_out, target, reduction = 'none')  # (B, T, C, H, W)
    loss = reduce(loss, 'b ... -> b', 'mean')

    # SNR-based per-timestep weighting
    loss = loss * extract(self.loss_weight, t, loss.shape)
    return loss.mean()
|
| 585 |
+
|
| 586 |
+
@torch.no_grad()
def forward(self, input_x, include_mu=False, **kwargs):
    """Predict future frames; optionally also return the backbone's output."""
    pred, mu = self.predict(input_x, compute_loss=False)
    return (pred, mu) if include_mu else pred
|
| 593 |
+
|
| 594 |
+
from stldm.modules import SimVPV2_Model, VAE
|
| 595 |
+
def model_setup(model_config, print_info=False, cfg_str=None):
    """Build the full GaussianDiffusion model (SimVPv2 backbone + latent diffusion U-Net).

    Args:
        model_config: dict with 'vp_param', 'stldm_param' and 'param' sections.
        print_info: log a short setup banner.
        cfg_str: constant CFG weight; None disables guidance.
    """
    if print_info:
        print('Setup the model with considering temporal attention be (BHW, T, C) ... ...')
        print('Train it from end to end')

    vpm = SimVPV2_Model(**model_config['vp_param'])
    ldm = LDM(**model_config['stldm_param'])
    model = GaussianDiffusion(vp_model=vpm, diffusion=ldm, **model_config['param'])

    if cfg_str is not None:
        scheduler = guidance_scheduler(sampling_step=model_config['param']['timesteps'], const=cfg_str)
    else:
        scheduler = None
    model.setup_guidance(scheduler)

    return model
|
| 610 |
+
|
| 611 |
+
def ae_setup(model_config):
    """Build a backbone from config and return only its VAE (autoencoder)."""
    vpm = SimVPV2_Model(**model_config['vp_param'])
    return vpm.vae
|
| 616 |
+
|
| 617 |
+
def backbone_setup(model_config):
    """Build and return the SimVPv2 backbone alone."""
    return SimVPV2_Model(**model_config['vp_param'])
|
stldm/stldm_spatial.py
ADDED
|
@@ -0,0 +1,593 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch, random
|
| 2 |
+
from torch import nn
|
| 3 |
+
from einops import rearrange
|
| 4 |
+
|
| 5 |
+
from stldm.submodules import *
|
| 6 |
+
|
| 7 |
+
class Down_Block(nn.Module):
    """Encoder stage of the spatial diffusion U-Net.

    Concatenates the conditioning map channel-wise (zeros when unconditional),
    runs ResNet -> spatial attention -> ResNet, then downsamples (or converts
    channels at the last stage). Returns both intermediate activations as skips.
    """

    def __init__(self, in_ch, hid_ch, out_ch, time_dim, is_last, patch_size=None, num_groups=8, heads=4, dim_head=32):
        super(Down_Block, self).__init__()
        self.block1 = ResnetBlock(dim=in_ch, dim_out=hid_ch, time_emb_dim=time_dim, groups=num_groups)
        if patch_size is None:
            self.attn1 = Residual(PreNorm(hid_ch, Quadratic_SpatialAttention(dim=hid_ch, heads=heads, dim_head=dim_head)))
        else:
            self.attn1 = Residual(PreNorm(hid_ch, Linear_SpatialAttention(dim=hid_ch, patch_size=patch_size, heads=heads, dim_head=dim_head)))
        self.block2 = ResnetBlock(dim=hid_ch, dim_out=hid_ch, groups=num_groups)
        self.attn2 = nn.Identity()  # second attention currently disabled
        self.last = ChannelConversion(hid_ch, out_ch) if is_last else Downsample2D(dim_in=hid_ch, dim_out=out_ch)

    def forward(self, x, time_emb, cond=None):
        assert x.ndim==5
        B, T, C, H, W = x.shape

        x = x.reshape(B*T, C, H, W)
        if cond is None:
            cond = torch.zeros_like(x)  # unconditional pass

        # broadcast the (B, C) time embedding to every frame -> (B*T, C)
        time_emb = time_emb.unsqueeze(1).repeat(1, T, 1).reshape(B*T, -1)

        out = torch.cat((x, cond), dim=1)  # (B*T, 2C, H, W)
        out = self.block1(out, time_emb)

        skip1 = self.attn1(out)
        skip2 = self.attn2(self.block2(skip1, time_emb))

        out = self.last(skip2)
        *_, c, h, w = out.shape

        return out.reshape(B, T, c, h, w), skip1, skip2
|
| 41 |
+
|
| 42 |
+
class MidBlock(nn.Module):
    """U-Net bottleneck: ResNet -> spatial attention -> ResNet -> ResNet (attn2 disabled)."""

    def __init__(self, in_ch, time_dim, num_groups=8, heads=4, dim_head=32):
        super(MidBlock, self).__init__()
        self.block1 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)
        self.attn1 = Residual(PreNorm(in_ch, Quadratic_SpatialAttention(dim=in_ch, heads=heads, dim_head=dim_head)))
        self.block2 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)
        self.attn2 = nn.Identity()  # second attention currently disabled
        self.block3 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)

    def forward(self, x, time_emb, relative_pos=None):
        # `relative_pos` is accepted for interface compatibility but unused here.
        assert x.ndim==5
        B, T, C, H, W = x.shape
        x = x.reshape(B*T, C, H, W)

        # broadcast the (B, C) time embedding to every frame -> (B*T, C)
        time_emb = time_emb.unsqueeze(1).repeat(1, T, 1).reshape(B*T, -1)

        out = self.attn1(self.block1(x, time_emb))
        out = self.attn2(self.block2(out, time_emb))
        out = self.block3(out, time_emb)

        *_, c, h, w = out.shape
        return out.reshape(B, T, c, h, w)
|
| 69 |
+
|
| 70 |
+
class Up_Block(nn.Module):
    """Decoder stage: upsample, attend, then two ResNet blocks each concatenated
    with a skip connection from the matching Down_Block.

    NOTE(review): forward() receives both `skip1` and `skip2` but concatenates
    `skip2` into BOTH ResNet blocks; `skip1` is never used. This may be a latent
    bug, but rerouting would change the behavior of trained checkpoints —
    confirm against the training code before altering.
    """

    def __init__(self, in_chs, hid_ch, out_ch, is_last, time_dim, patch_size=None, num_groups=8, heads=4, dim_head=32):
        super(Up_Block, self).__init__()
        in_ch, skip_ch = in_chs
        self.up = Upsample2D(dim_in=in_ch, dim_out=hid_ch) if not is_last else ChannelConversion(in_ch, hid_ch)
        self.attn1 = Residual(PreNorm(hid_ch, Quadratic_SpatialAttention(dim=hid_ch, heads=heads, dim_head=dim_head) if patch_size is None else Linear_SpatialAttention(dim=hid_ch, patch_size=patch_size, heads=heads, dim_head=dim_head)))
        self.block1 = ResnetBlock(dim=hid_ch+skip_ch, dim_out=hid_ch, time_emb_dim=time_dim, groups=num_groups)
        self.attn2 = nn.Identity()  # second attention currently disabled
        self.block2 = ResnetBlock(dim=hid_ch+skip_ch, dim_out=out_ch, time_emb_dim=time_dim, groups=num_groups)

    def forward(self, x, time_emb, skip1, skip2):
        assert x.ndim==5
        B, T, C, H, W = x.shape
        x = x.reshape(B*T, C, H, W)

        # broadcast the (B, C) time embedding to every frame -> (B*T, C)
        time_emb = time_emb.unsqueeze(1).repeat(1, T, 1).reshape(B*T, -1)

        out = self.up(x)
        # (removed dead `*_, c, h, w = out.shape` here — values were immediately discarded)
        out = self.attn1(out)

        out = torch.cat((out, skip2), dim=1)
        out = self.block1(out, time_emb)

        out = self.attn2(out)

        out = torch.cat((out, skip2), dim=1)
        out = self.block2(out, time_emb)
        *_, c, h, w = out.shape
        return out.reshape(B, T, c, h, w)
|
| 103 |
+
|
| 104 |
+
class LDM(nn.Module):
    """Conditional latent-diffusion U-Net over per-frame latents (B, T, C, H, W).

    The conditioning map is channel-concatenated at every encoder stage and
    downsampled in lockstep with the feature maps via `self.conditions`.

    Args:
        in_ch: latent channel count.
        chs_mult: per-level channel multipliers over `base_ch`.
        patch_size: when given, use linear (patch) spatial attention; must be
            divisible by 2**level.
        base_ch: base channel width (also the timestep-embedding fourier dim).
    """

    def __init__(self, in_ch, chs_mult:tuple, patch_size=None, num_groups=8, heads=4, dim_head=32, base_ch=64):
        super(LDM, self).__init__()
        # Timestep-embedding MLP
        time_dim = 4*base_ch
        fourier_dim = base_ch
        self.time_mlp = Time_MLP(dim=base_ch, time_dim=time_dim, fourier_dim=fourier_dim)

        ups, downs = [], []
        conditions = []

        layer_no = len(chs_mult)
        chs = [in_ch, *map(lambda m: base_ch*m, chs_mult)]
        ch_in, ch_out = chs[:-1], chs[1:]
        up_in, up_out = list(reversed(ch_out)), list(reversed(ch_in))

        patches = None if patch_size is None else [patch_size//(2**n) for n in range(layer_no)]  # Patch Size should be 2^N
        for n in range(layer_no):
            downs.append(
                Down_Block(in_ch=2*ch_in[n], hid_ch=ch_in[n], out_ch=ch_out[n], time_dim=time_dim, patch_size=None if patch_size is None else patches[n], is_last=(n==layer_no-1), num_groups=num_groups, heads=heads, dim_head=dim_head)
            )
            ups.append(
                Up_Block(in_chs=(up_in[n], ch_in[-n-1]), hid_ch=up_in[n], out_ch=up_out[n], time_dim=time_dim, patch_size=None if patch_size is None else patches[layer_no-n-1], is_last=(n==0), num_groups=num_groups, heads=heads, dim_head=dim_head)
            )
            # Fix: removed the always-true `if n != -1:` guard (range() never yields -1).
            conditions.append(
                Downsample2D(dim_in=ch_in[n], dim_out=ch_out[n])
            )

        self.downs = nn.ModuleList(downs)
        self.ups = nn.ModuleList(ups)
        self.conditions = nn.ModuleList(conditions)
        self.mid = MidBlock(in_ch=ch_out[-1], time_dim=time_dim, num_groups=num_groups, heads=heads, dim_head=dim_head)

    def forward(self, x, time, conds=None):
        """Denoise `x` at timestep `time`; `conds=None` runs unconditionally."""
        t = self.time_mlp(time)
        hids1, hids2 = [], []

        for n, down_block in enumerate(self.downs):
            x, skip1, skip2 = down_block(x, t, conds)
            hids1.append(skip1)
            hids2.append(skip2)
            if conds is not None:
                conds = self.conditions[n](conds)

        out = self.mid(x, t)

        for up_block in self.ups:
            out = up_block(out, t, hids1.pop(), hids2.pop())
        return out
|
| 154 |
+
|
| 155 |
+
# constants
|
| 156 |
+
from collections import namedtuple
|
| 157 |
+
from torch.cuda.amp import autocast
|
| 158 |
+
import torch.nn.functional as F
|
| 159 |
+
from einops import reduce
|
| 160 |
+
from tqdm.auto import tqdm
|
| 161 |
+
|
| 162 |
+
ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
|
| 163 |
+
|
| 164 |
+
def identity(t, *args, **kwargs):
    """Return `t` unchanged, ignoring any extra arguments (no-op clamp stand-in)."""
    return t
|
| 166 |
+
|
| 167 |
+
def extract(a, t, x_shape):
    """Gather per-timestep coefficients a[t] and reshape them to broadcast over x_shape."""
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    trailing = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *trailing)
|
| 171 |
+
|
| 172 |
+
def default(val, d):
    """Return `val` when it is not None; otherwise `d` (called first if callable)."""
    if val is not None:
        return val
    return d() if callable(d) else d
|
| 176 |
+
|
| 177 |
+
def exists(x):
    """True iff `x` is not None (the codebase's standard null check)."""
    return x is not None
|
| 179 |
+
|
| 180 |
+
def guidance_scheduler(sampling_step: int, const: float):
    """Constant classifier-free-guidance weight schedule of length `sampling_step`."""
    return torch.ones(sampling_step) * const
|
| 182 |
+
|
| 183 |
+
class GaussianDiffusion(nn.Module):
|
| 184 |
+
def __init__(
    self,
    vp_model,
    diffusion,
    timesteps = 1000,
    sampling_timesteps = None,
    objective = 'pred_v',
    beta_schedule = 'sigmoid',
    schedule_fn_kwargs = dict(),
    ddim_sampling_eta = 0.,
    offset_noise_strength = 0.,  # https://www.crosslabs.org/blog/diffusion-with-offset-noise
    min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
    min_snr_gamma = 5
):
    """Latent video diffusion wrapper: SimVP backbone + denoising U-Net,
    with precomputed DDPM/DDIM schedule buffers and SNR-based loss weights.
    """
    super(GaussianDiffusion, self).__init__()

    self.backbone = vp_model
    self.diff_unet = diffusion

    self.objective = objective
    assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'

    # beta schedule dispatch
    schedule_fns = {
        'linear': linear_beta_schedule,
        'cosine': cosine_beta_schedule,
        'sigmoid': sigmoid_beta_schedule,
    }
    if beta_schedule not in schedule_fns:
        raise ValueError(f'unknown beta schedule {beta_schedule}')
    betas = schedule_fns[beta_schedule](timesteps, **schedule_fn_kwargs)

    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

    timesteps, = betas.shape
    self.num_timesteps = int(timesteps)

    # sampling-related parameters: fewer steps than training -> DDIM
    self.sampling_timesteps = default(sampling_timesteps, timesteps)
    assert self.sampling_timesteps <= timesteps
    self.is_ddim_sampling = self.sampling_timesteps < timesteps
    self.ddim_sampling_eta = ddim_sampling_eta

    # register schedule tensors as float32 buffers so they move with the module
    def register(name, val):
        self.register_buffer(name, val.to(torch.float32))

    register('betas', betas)
    register('alphas_cumprod', alphas_cumprod)
    register('alphas_cumprod_prev', alphas_cumprod_prev)

    # coefficients for the forward process q(x_t | x_{t-1}) and friends
    register('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
    register('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
    register('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
    register('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
    register('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

    # posterior q(x_{t-1} | x_t, x_0)
    posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
    # equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
    register('posterior_variance', posterior_variance)
    # clamp before log: posterior variance is 0 at the start of the chain
    register('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
    register('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
    register('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

    # offset noise strength - the blogpost claims 0.1 is ideal
    self.offset_noise_strength = offset_noise_strength

    # loss weight derived from the signal-to-noise ratio
    snr = alphas_cumprod / (1 - alphas_cumprod)

    # min-SNR clipping: https://arxiv.org/abs/2303.09556
    maybe_clipped_snr = snr.clone()
    if min_snr_loss_weight:
        maybe_clipped_snr.clamp_(max = min_snr_gamma)

    if objective == 'pred_noise':
        register('loss_weight', maybe_clipped_snr / snr)
    elif objective == 'pred_x0':
        register('loss_weight', maybe_clipped_snr)
    elif objective == 'pred_v':
        register('loss_weight', maybe_clipped_snr / (snr + 1))
|
| 283 |
+
|
| 284 |
+
@property
def device(self):
    """Device on which the registered diffusion buffers (and thus the model) live."""
    return self.betas.device
|
| 287 |
+
|
| 288 |
+
# Classifier-free guidance: attach a pre-built per-timestep weight scheduler.
def setup_guidance(self, scheduler):
    """Store a CFG weight schedule (1-D tensor indexed by timestep), or None to disable guidance."""
    self.CFG_sch = scheduler.to(self.device) if exists(scheduler) else scheduler
|
| 294 |
+
|
| 295 |
+
def predict_start_from_noise(self, x_t, t, noise):
    """Recover the clean latent x_0 from noisy x_t and the predicted noise."""
    recip = extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape)
    recipm1 = extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    return recip * x_t - recipm1 * noise
|
| 300 |
+
|
| 301 |
+
def predict_noise_from_start(self, x_t, t, x0):
|
| 302 |
+
return (
|
| 303 |
+
(extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
|
| 304 |
+
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
def predict_v(self, x_start, t, noise):
|
| 308 |
+
return (
|
| 309 |
+
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
|
| 310 |
+
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
def predict_start_from_v(self, x_t, t, v):
|
| 314 |
+
return (
|
| 315 |
+
extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
|
| 316 |
+
extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
def q_posterior(self, x_start, x_t, t):
|
| 320 |
+
posterior_mean = (
|
| 321 |
+
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
| 322 |
+
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
| 323 |
+
)
|
| 324 |
+
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
| 325 |
+
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
| 326 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
| 327 |
+
|
| 328 |
+
    def model_predictions(self, x, t, cond, clip_x_start = False, rederive_pred_noise = False):
        """Run the denoising UNet (with optional classifier-free guidance) and
        return both the predicted noise and the predicted clean latent.

        Args:
            x: noisy latent at timestep t.
            t: (B,) long tensor of timesteps. When CFG is active only t[0] indexes
               the guidance schedule, so the whole batch is assumed to share one step.
            cond: conditioning tensor from the backbone, or None.
            clip_x_start: clamp the recovered x_0 into [-1, 1].
            rederive_pred_noise: recompute the noise from the clamped x_0.

        Returns:
            ModelPrediction(pred_noise, pred_x_start)
        """
        # print(t.device)
        if exists(self.CFG_sch):
            # Classifier-free guidance. Note the rearranged form:
            # cond - w*(uncond - cond) == (1 + w)*cond - w*uncond.
            uncond = self.diff_unet(x, t, conds=None) #conds=torch.zeros_like(cond))
            model_output = self.diff_unet(x, t, conds=cond)
            # Assumes all batch elements share the same timestep.
            time = int(t[0])
            model_output = model_output - self.CFG_sch[time] * (uncond - model_output)
        else:
            model_output = self.diff_unet(x, t, conds=cond)
        maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity

        # Interpret the network output according to the training objective.
        if self.objective == 'pred_noise':
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, pred_noise)
            x_start = maybe_clip(x_start)

            if clip_x_start and rederive_pred_noise:
                # Keep the noise consistent with the clamped x_0.
                pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_x0':
            x_start = model_output
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_v':
            v = model_output
            x_start = self.predict_start_from_v(x, t, v)
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        return ModelPrediction(pred_noise, x_start)
|
| 359 |
+
|
| 360 |
+
def p_mean_variance(self, x, t, cond=None, clip_denoised = True):
|
| 361 |
+
preds = self.model_predictions(x, t, cond=cond, clip_x_start=False,)
|
| 362 |
+
x_start = preds.pred_x_start
|
| 363 |
+
|
| 364 |
+
if clip_denoised:
|
| 365 |
+
x_start.clamp_(-1., 1.)
|
| 366 |
+
|
| 367 |
+
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
|
| 368 |
+
return model_mean, posterior_variance, posterior_log_variance, x_start
|
| 369 |
+
|
| 370 |
+
@torch.no_grad()
|
| 371 |
+
def p_sample(self, x, t: int, cond=None):
|
| 372 |
+
b, *_, device = *x.shape, self.device
|
| 373 |
+
batched_times = torch.full((b,), t, device = device, dtype = torch.long)
|
| 374 |
+
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, cond=cond, clip_denoised = False)
|
| 375 |
+
noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
|
| 376 |
+
pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
|
| 377 |
+
return pred_img, x_start
|
| 378 |
+
|
| 379 |
+
@torch.no_grad()
|
| 380 |
+
def p_sample_loop(self, shape, cond=None, return_all_timesteps = False):
|
| 381 |
+
batch, device = shape[0], self.device
|
| 382 |
+
|
| 383 |
+
frames_pred = torch.randn(shape, device = device)
|
| 384 |
+
imgs = [frames_pred]
|
| 385 |
+
|
| 386 |
+
for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps, disable=True):
|
| 387 |
+
frames_pred, _ = self.p_sample(frames_pred, t, cond=cond)
|
| 388 |
+
imgs.append(frames_pred)
|
| 389 |
+
|
| 390 |
+
ret = frames_pred if not return_all_timesteps else torch.stack(imgs, dim = 1)
|
| 391 |
+
return ret
|
| 392 |
+
|
| 393 |
+
    @torch.no_grad()
    def ddim_sample(self, shape, cond=None, return_all_timesteps = False):
        """DDIM sampling (Song et al., arXiv:2010.02502) over `sampling_timesteps` steps.

        Args:
            shape: target latent shape (B, T, C, H, W).
            cond: conditioning tensor, or None.
            return_all_timesteps: also return intermediate latents, stacked on dim 1.
        """
        batch, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
        device = self.device
        # Evenly spaced step indices; -1 marks the terminal "clean" state.
        times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        frames_pred = torch.randn(shape, device = device)
        imgs = [frames_pred]

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', disable=True):
            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
            pred_noise, x_start, *_ = self.model_predictions(
                frames_pred,
                time_cond,
                cond = cond, #cond.copy(),
                clip_x_start = False,
                rederive_pred_noise = True
            )

            # Final transition: the model's x_0 estimate is the sample.
            if time_next < 0:
                frames_pred = x_start
                imgs.append(frames_pred)
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            # eta = 0 gives deterministic DDIM; larger eta injects DDPM-like noise.
            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()

            noise = torch.randn_like(frames_pred)

            frames_pred = x_start * alpha_next.sqrt() + \
                            c * pred_noise + \
                            sigma * noise

            imgs.append(frames_pred)

        ret = frames_pred if not return_all_timesteps else torch.stack(imgs, dim = 1)
        return ret
|
| 435 |
+
|
| 436 |
+
    @torch.no_grad()
    def sample(self, frames_in, return_all_timesteps = False):
        """Generate future frames: encode context with the backbone, sample latents
        with the diffusion model, then decode with the VAE.

        Args:
            frames_in: (B, T_in, C, H, W) context frames.
        Returns:
            (frames_pred, backbone_output)
        """
        assert frames_in.ndim == 5
        B, T_in, C, H, W = frames_in.shape
        device = self.device

        backbone_output, conds, *_ = self.backbone(frames_in) # updated for Updated loss function on 03/07
        sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample

        # Target latent shape is inferred from the conditioning tensor.
        *_, c, h, w = conds.shape
        tgt_shape = conds.reshape(B, -1, c, h, w).shape
        ldm_pred = sample_fn(
            tgt_shape,
            cond=conds,
            return_all_timesteps = return_all_timesteps
        )

        # NOTE(review): with return_all_timesteps=True the sampler returns a 6-D
        # tensor, which this 5-D rearrange would reject — presumably only the
        # default path is exercised; verify before exposing that flag.
        ldm_pred = rearrange(ldm_pred, 'b t c h w -> (b t) c h w')
        frames_pred = self.backbone.vae.decode(ldm_pred)
        frames_pred = rearrange(frames_pred, '(b t) c h w -> b t c h w', b=B)
        return frames_pred, backbone_output
|
| 457 |
+
|
| 458 |
+
def predict(self, frames_in, compute_loss=False, **kwargs):
|
| 459 |
+
pred, mu = self.sample(frames_in=frames_in)
|
| 460 |
+
return pred, mu
|
| 461 |
+
|
| 462 |
+
    def compute_loss(self, frames_in, frames_gt, validate=False):
        """Compute the four training losses: VAE reconstruction, backbone (mean)
        prediction, latent diffusion, and the prior KL at t = T.

        Args:
            frames_in: (B, T_in, C, H, W) context frames.
            frames_gt: (B, T_out, C, H, W) ground-truth future frames.
            validate: flag computed into `compute_loss` below (currently unused).

        Returns:
            (recon_loss, mu_loss, diff_loss, prior_loss)
        """
        # NOTE(review): `compute_loss`, `T_out` and `device` are assigned but never
        # read in this method.
        compute_loss = True and (not validate)
        B, T_in, C, H, W = frames_in.shape
        T_out = frames_gt.shape[1]
        device = frames_in.device

        """
        Diffusion Loss
        """
        # NOTE(review): here backbone() is unpacked into exactly two values, while
        # sample() uses `backbone_output, conds, *_` — confirm the backbone's return arity.
        backbone_output, conds = self.backbone(frames_in)
        hid_gt, _ = self.backbone.vae.encode(
            rearrange(frames_gt, 'b t c h w -> (b t) c h w')
        )
        hid_gt = rearrange(hid_gt, '(b t) c h w -> b t c h w', b=B)
        t = torch.randint(0, self.num_timesteps, (B,), device=self.device).long()
        # Randomly drop conditioning ~15% of the time for classifier-free guidance.
        if random.random() > 0.85: # Unconditional
            conds = None
        # detach(): the diffusion loss must not backprop into the VAE encoder here.
        diff_loss = self.p_losses(hid_gt.detach(), t, cond=conds)

        """
        Backbone Loss
        """
        mu_loss = self.backbone._losses_(frames_in, frames_gt)

        """
        VAE Loss
        """
        ae_loss, kl_loss = self.backbone.vae._losses_(
            rearrange(torch.cat((frames_in, frames_gt), dim=1), 'b t c h w -> (b t) c h w'),
            rearrange(torch.cat((frames_in, frames_gt), dim=1), 'b t c h w -> (b t) c h w')
        )
        kl_weight = 1E-6
        recon_loss = ae_loss + kl_weight*kl_loss

        """
        Prior Loss at t=T [Noisy]
        """
        # NOTE(review): frames_gt is encoded a second time here — possibly intentional
        # if encode() samples stochastically; otherwise the first hid_gt could be reused.
        hid_gt, _ = self.backbone.vae.encode(
            rearrange(frames_gt, 'b t c h w -> (b t) c h w')
        )
        hid_gt = rearrange(hid_gt, '(b t) c h w -> b t c h w', b=B)
        T = torch.ones((B,), device=self.device).long() * (self.num_timesteps - 1)
        # q(x_T | x_0) should be close to N(0, I); penalise its KL to the standard normal.
        mu_noisy = extract(self.sqrt_alphas_cumprod, T, hid_gt.shape) * hid_gt
        sigma_noisy = extract(self.sqrt_one_minus_alphas_cumprod, T, hid_gt.shape)
        log_var_noisy = 2*torch.log(sigma_noisy)
        prior_loss = self.kl_from_standard_normal(mu_noisy, log_var_noisy)

        return recon_loss, mu_loss, diff_loss, prior_loss
|
| 510 |
+
|
| 511 |
+
def kl_from_standard_normal(self, mean, log_var):
|
| 512 |
+
kl = 0.5 * (log_var.exp() + mean.square() - 1.0 - log_var)
|
| 513 |
+
return kl.mean()
|
| 514 |
+
|
| 515 |
+
@autocast(enabled = False)
|
| 516 |
+
def q_sample(self, x_start, t, noise = None):
|
| 517 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
| 518 |
+
|
| 519 |
+
return (
|
| 520 |
+
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
| 521 |
+
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
    def p_losses(self, x_start, t, noise=None, offset_noise_strength=None, cond=None):
        """DDPM training loss for a batch of clean latents.

        Args:
            x_start: (B, T, C, H, W) clean latent targets.
            t: (B,) sampled diffusion timesteps.
            noise: optional pre-drawn Gaussian noise.
            offset_noise_strength: overrides the instance setting when given.
            cond: conditioning tensor (None trains the unconditional branch).

        Returns:
            Scalar SNR-weighted MSE loss.
        """
        b, T, c, h, w = x_start.shape

        noise = default(noise, lambda: torch.randn_like(x_start))

        # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
        offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)

        if offset_noise_strength > 0.:
            offset_noise = torch.randn(x_start.shape[:2], device = self.device)
            # NOTE(review): noise is 5-D but this rearrange yields a 4-D tensor —
            # confirm the intended broadcast before enabling offset noise (> 0).
            noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')

        # noise sample
        x = self.q_sample(x_start=x_start, t=t, noise=noise) # Use q_sample here for updating: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L763

        model_out = self.diff_unet(x, t, conds=cond)

        # Regression target depends on the training parameterisation.
        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_start
        elif self.objective == 'pred_v':
            v = self.predict_v(x_start, t, noise)
            target = v
        else:
            raise ValueError(f'unknown objective {self.objective}')

        loss = F.mse_loss(model_out, target, reduction = 'none') # (B, T, C, H, W)
        loss = reduce(loss, 'b ... -> b', 'mean')

        # Min-SNR loss weighting (https://arxiv.org/abs/2303.09556).
        loss = loss * extract(self.loss_weight, t, loss.shape)
        return loss.mean()
|
| 556 |
+
|
| 557 |
+
@torch.no_grad()
|
| 558 |
+
def forward(self, input_x, include_mu=False, **kwargs):
|
| 559 |
+
pred, mu = self.predict(input_x, compute_loss=False)
|
| 560 |
+
if include_mu:
|
| 561 |
+
return pred, mu
|
| 562 |
+
else:
|
| 563 |
+
return pred
|
| 564 |
+
|
| 565 |
+
from stldm.modules import SimVPV2_Model, VAE
|
| 566 |
+
def model_setup(model_config, print_info=False, cfg_str=None):
    """Assemble the full model: SimVPV2 backbone + latent diffusion UNet wrapped
    in GaussianDiffusion, with an optional classifier-free-guidance schedule.

    Args:
        model_config: dict with 'vp_param', 'stldm_param' and 'param' sections.
        print_info: print a short description of the configuration.
        cfg_str: guidance-strength constant; None disables CFG entirely.
    """
    if print_info:
        print('Setup a Spatial diffusion with replacing a Temporal attention with Spatial attention')
        print('This is a diffusion with the consideration of (BT, C, H, W)')
        print('Train it from end to end')

    vp_config = model_config['vp_param']
    ldm_config = model_config['stldm_param']

    vpm = SimVPV2_Model(**vp_config)
    ldm = LDM(**ldm_config)
    model = GaussianDiffusion(vp_model=vpm, diffusion=ldm, **model_config['param'])

    # Optional CFG schedule, one weight per diffusion timestep.
    scheduler = guidance_scheduler(sampling_step=model_config['param']['timesteps'], const=cfg_str) if cfg_str is not None else None
    model.setup_guidance(scheduler)

    return model
|
| 583 |
+
|
| 584 |
+
def ae_setup(model_config):
    """Build only the VAE of the SimVPV2 backbone described by ``model_config``."""
    backbone = SimVPV2_Model(**model_config['vp_param'])
    return backbone.vae
|
| 589 |
+
|
| 590 |
+
def backbone_setup(model_config):
    """Build the SimVPV2 backbone described by ``model_config['vp_param']``."""
    return SimVPV2_Model(**model_config['vp_param'])
|
stldm/submodules.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch, math
|
| 2 |
+
from torch import nn
|
| 3 |
+
from einops import rearrange
|
| 4 |
+
|
| 5 |
+
# building block modules
|
| 6 |
+
def exists(x):
    """Return True when *x* is not None (falsy values such as 0 or '' still count)."""
    return x is not None
|
| 8 |
+
|
| 9 |
+
class LayerNorm(nn.Module):
    """Channel-wise LayerNorm for NCHW tensors with a learnable per-channel gain (no bias)."""

    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        # Looser epsilon under reduced precision to keep rsqrt stable.
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        mean = torch.mean(x, dim=1, keepdim=True)
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        normed = (x - mean) * torch.rsqrt(var + eps)
        return normed * self.g
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class PreNorm(nn.Module):
    """Apply channel LayerNorm to the input before running the wrapped module."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = LayerNorm(dim)
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, *args, **kwargs)
|
| 30 |
+
|
| 31 |
+
class Residual(nn.Module):
    """Wrap a module with an additive skip connection: y = fn(x, ...) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        out = self.fn(x, *args, **kwargs)
        return out + x
|
| 38 |
+
|
| 39 |
+
class Block(nn.Module):
    """Conv3x3 -> GroupNorm -> optional affine (scale, shift) -> SiLU."""

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        # Fall back to a single group when dim_out is not divisible by `groups`.
        if dim_out % groups != 0:
            groups = 1
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        h = self.norm(self.proj(x))
        if scale_shift is not None:
            # FiLM-style modulation, e.g. from a time embedding.
            scale, shift = scale_shift
            h = h * (scale + 1) + shift
        return self.act(h)
|
| 58 |
+
|
| 59 |
+
class ResnetBlock(nn.Module):
    """Two `Block` convolutions with an optional time-embedding FiLM path and a
    1x1 residual projection when the channel count changes."""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        if time_emb_dim is not None:
            self.mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_emb_dim, dim_out * 2),
            )
        else:
            self.mlp = None

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        scale_shift = None
        if self.mlp is not None and time_emb is not None:
            # Project the time embedding to per-channel (scale, shift) pairs.
            emb = rearrange(self.mlp(time_emb), 'b c -> b c 1 1')
            scale_shift = emb.chunk(2, dim=1)

        h = self.block1(x, scale_shift=scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)
|
| 84 |
+
|
| 85 |
+
"""
|
| 86 |
+
Input Tensor and Output Tensor should be in the format of (BT, C, H, W) with # dims = 4
|
| 87 |
+
"""
|
| 88 |
+
class Linear_SpatialAttention(nn.Module):
    """Linear (softmax-factorised) self-attention over non-overlapping spatial
    patches of a (BT, C, H, W) tensor; output shape equals input shape.

    Assumes H and W are divisible by `patch_size`.
    """
    def __init__(self, dim, patch_size, heads=4, dim_head=32):
        super(Linear_SpatialAttention, self).__init__()
        self.scale = dim_head ** -0.5
        self.patch_size = patch_size
        self.heads = heads
        hidden_dim = dim_head*heads # No of Channel for (Q, K, V)
        self.to_qkv = nn.Conv2d(dim, hidden_dim*3, kernel_size=1, padding=0, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, kernel_size=1),
            LayerNorm(dim)
        )

    def forward(self, x):
        assert x.ndim == 4
        BT, C, H, W = x.shape
        nh, nw = H//self.patch_size, W//self.patch_size
        qkv = self.to_qkv(x).chunk(3, dim=1) # qkv tuple in (q, k , v)
        # [B, Head × C, X × P, Y × P] -> [B, Head × X × Y, C, P × P]
        q, k, v = map(lambda t: rearrange(t, 'b (h c) (nh ph) (nw pw) -> b (h nh nw) c (ph pw)', h=self.heads, ph=self.patch_size, pw=self.patch_size, nh=nh, nw=nw), qkv)
        # Linear attention: normalise q over channels and k over positions, then
        # aggregate the k·v context before applying q (linear in patch pixels).
        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)
        q = q*self.scale

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b (h nh nw) c (ph pw) -> b (h c) (nh ph) (nw pw)', h=self.heads, ph=self.patch_size, pw=self.patch_size, nh=nh, nw=nw)
        out = self.to_out(out)
        return out
|
| 117 |
+
|
| 118 |
+
"""
|
| 119 |
+
Input Tensor and Output Tensor should be in the format of (B, T, C, H, W) with # dims = 5
|
| 120 |
+
"""
|
| 121 |
+
class Linear_TemporalAttention(nn.Module):
    """Linear (softmax-factorised) attention along the time axis of a
    (B, T, C, H, W) tensor; shape-preserving."""
    def __init__(self, dim, heads=4, dim_head=32):
        super(Linear_TemporalAttention, self).__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head*heads # No of Channel for (Q, K, V)
        self.to_qkv = nn.Conv2d(dim, hidden_dim*3, kernel_size=1, padding=0, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, kernel_size=1),
            LayerNorm(dim)
        )

    def forward(self, x):
        assert x.ndim == 5
        B, T, C, H, W = x.shape
        x = x.reshape(B*T, C, H, W)
        qkv = self.to_qkv(x).chunk(3, dim=1) # qkv tuple in (q, k , v)
        # Tokens are timesteps: one attention problem per (head, pixel) pair.
        q, k, v = map(lambda t: rearrange(t, '(b t) (h c) x y -> b (h x y) c t', h=self.heads, x=H, y=W, t=T), qkv)
        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)
        q = q*self.scale
        # NOTE(review): in-place division on a tensor produced by rearrange —
        # confirm autograd accepts this during training.
        v /= (H*W)

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b (h x y) c t -> (b t) (h c) x y', h=self.heads, x=H, y=W, t=T)
        out = self.to_out(out)
        return out.reshape(B, T, C, H, W)
|
| 150 |
+
|
| 151 |
+
# Does not Follow what suggested by the paper as could not ensure the spatial factor of 2
|
| 152 |
+
def Downsample2D(dim_in, dim_out):
    """Halve the spatial resolution, (H, W) -> (H/2, W/2), with a strided 4x4 conv."""
    conv = nn.Conv2d(dim_in, dim_out, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    return conv
|
| 154 |
+
|
| 155 |
+
def Upsample2D(dim_in, dim_out):
    """Double the spatial resolution, (H, W) -> (2H, 2W), with a 4x4 transposed conv."""
    deconv = nn.ConvTranspose2d(dim_in, dim_out, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    return deconv
|
| 157 |
+
|
| 158 |
+
def ChannelConversion(dim_in, dim_out):
    """Change the channel count with a shape-preserving 3x3 convolution."""
    conv = nn.Conv2d(dim_in, dim_out, kernel_size=(3, 3), padding=(1, 1))
    return conv
|
| 160 |
+
|
| 161 |
+
"""
|
| 162 |
+
Input Tensor and Output Tensor should be in the format of (BT, C, H, W) with # dims = 4
|
| 163 |
+
"""
|
| 164 |
+
class Quadratic_SpatialAttention(nn.Module):
    """Full (quadratic) softmax self-attention across all H*W spatial positions
    of a (BT, C, H, W) tensor; output shape equals input shape."""
    def __init__(self, dim, heads=4, dim_head=32):
        super(Quadratic_SpatialAttention, self).__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head*heads # No of Channel for (Q, K, V)
        self.to_qkv = nn.Conv2d(dim, hidden_dim*3, kernel_size=1, padding=0, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, kernel_size=1)
        )

    def forward(self, x):
        assert x.ndim == 4
        BT, C, H, W = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1) # qkv tuple in (q, k , v)
        # Flatten the spatial grid into tokens: (b, heads, dim_head, H*W).
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv)
        q = q*self.scale

        # Standard scaled dot-product attention over spatial positions.
        sim = torch.einsum('b h d i, b h d j -> b h i j', q, k)
        attn = sim.softmax(dim = -1)
        out = torch.einsum('b h i j, b h d j -> b h i d', attn, v)

        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = H, y = W)

        out = self.to_out(out)
        return out
|
| 191 |
+
|
| 192 |
+
"""
|
| 193 |
+
Input Tensor and Output Tensor should be in the format of (B, T, C, H, W) with # dims = 5
|
| 194 |
+
"""
|
| 195 |
+
class Quadratic_TemporalAttention(nn.Module):
    """Full softmax attention along the time axis of a (B, T, C, H, W) tensor,
    with each timestep's entire (C*H*W) feature map as one token; shape-preserving."""
    def __init__(self, dim, heads=4, dim_head=32):
        super(Quadratic_TemporalAttention, self).__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head*heads # No of Channel for (Q, K, V)
        self.to_qkv = nn.Conv2d(dim, hidden_dim*3, kernel_size=1, padding=0, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, kernel_size=1),
        )

    def forward(self, x):
        assert x.ndim == 5
        B, T, C, H, W = x.shape
        x = x.reshape(B*T, C, H, W)
        qkv = self.to_qkv(x).chunk(3, dim=1) # qkv tuple in (q, k , v)
        # Tokens are timesteps: (b, heads, dim_head*H*W, T).
        q, k, v = map(lambda t: rearrange(t, '(b t) (h c) x y -> b h (c x y) t', h=self.heads, x=H, y=W, t=T), qkv)
        q = q*self.scale

        sim = torch.einsum('b h d i, b h d j -> b h i j', q, k)
        attn = sim.softmax(dim = -1)
        out = torch.einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, 'b h t (c x y) -> (b t) (h c) x y', h=self.heads, x=H, y=W, t=T)
        out = self.to_out(out)
        return out.reshape(B, T, C, H, W)
|
| 221 |
+
|
| 222 |
+
"""
|
| 223 |
+
A series of functions required for Diffusion Model copied from DiffCast code
|
| 224 |
+
"""
|
| 225 |
+
def extract(a, t, x_shape):
    """Gather a[t] per batch element and reshape so it broadcasts against x_shape.

    Args:
        a: 1-D tensor of per-timestep coefficients.
        t: (B,) long tensor of timestep indices.
        x_shape: shape of the tensor the result will be multiplied with.
    """
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    # Append singleton dims so the coefficients broadcast over the remaining axes.
    return gathered.reshape(batch, *((1,) * (len(x_shape) - 1)))
|
| 229 |
+
|
| 230 |
+
def linear_beta_schedule(timesteps):
    """Linearly spaced beta schedule from the original DDPM paper.

    The endpoints are rescaled so any `timesteps` count matches the total noise
    of the canonical 1000-step schedule.
    """
    scale = 1000 / timesteps
    start, end = scale * 0.0001, scale * 0.02
    return torch.linspace(start, end, timesteps, dtype=torch.float64)
|
| 238 |
+
|
| 239 |
+
def cosine_beta_schedule(timesteps, s = 0.008):
    """Cosine beta schedule, as proposed in https://openreview.net/forum?id=-NEXDKk8gZ.

    Derives betas from a squared-cosine alphas_cumprod curve with offset `s`.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64) / timesteps
    alphas_cumprod = torch.cos((grid + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    # Clip to avoid singular betas at the very end of the schedule.
    return torch.clip(betas, 0, 0.999)
|
| 250 |
+
|
| 251 |
+
def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5):
    """Sigmoid beta schedule, proposed in https://arxiv.org/abs/2212.11972 (Fig. 8).

    Reported to work better for images larger than 64x64 during training.
    NOTE(review): `clamp_min` is accepted but never used; the lower clip is 0.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64) / timesteps
    v_start = torch.tensor(start / tau).sigmoid()
    v_end = torch.tensor(end / tau).sigmoid()
    alphas_cumprod = (-((grid * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)
|
| 265 |
+
|
| 266 |
+
# sinusoidal positional embeds
|
| 267 |
+
class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding of a scalar (e.g. a diffusion timestep).

    Maps a (B,) tensor to (B, dim): the first dim/2 columns are sines, the rest cosines.
    """

    def __init__(self, dim):
        super(SinusoidalPosEmb, self).__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        # Geometric frequency ladder from 1 down to 1/10000.
        freqs = torch.exp(
            torch.arange(half_dim, device=x.device) * (-math.log(10000) / (half_dim - 1))
        )
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
|
| 280 |
+
|
| 281 |
+
class Time_MLP(nn.Module):
    """Sinusoidal embedding followed by a two-layer GELU MLP for timestep conditioning.

    NOTE(review): `dim` is accepted but unused; the embedding width is `fourier_dim`.
    """

    def __init__(self, dim, time_dim, fourier_dim=32):
        super(Time_MLP, self).__init__()
        self.mlp = nn.Sequential(
            SinusoidalPosEmb(fourier_dim),
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

    def forward(self, x):
        return self.mlp(x)
|
| 293 |
+
|
| 294 |
+
class RelativePositionBias(nn.Module):
    """T5-style relative position bias: produces an (heads, n, n) additive
    attention-logit bias from bucketed relative distances."""
    def __init__(
        self,
        heads = 8,
        num_buckets = 32,
        max_distance = 128
    ):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, num_buckets = 32, max_distance = 128):
        # Map signed relative offsets to bucket ids: half the buckets encode the
        # sign, small distances get exact buckets, larger ones log-spaced ones.
        ret = 0
        n = -relative_position

        num_buckets //= 2
        ret += (n < 0).long() * num_buckets
        n = torch.abs(n)

        max_exact = num_buckets // 2
        is_small = n < max_exact

        # Log-spaced buckets for distances in [max_exact, max_distance).
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        # Distances beyond max_distance all share the last bucket.
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, n, device):
        # Pairwise relative offsets between all query/key positions 0..n-1.
        q_pos = torch.arange(n, dtype = torch.long, device = device)
        k_pos = torch.arange(n, dtype = torch.long, device = device)
        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        return rearrange(values, 'i j h -> h i j')
|
| 333 |
+
|
| 334 |
+
"""
|
| 335 |
+
Input Tensor and Output Tensor should be in the format of (B, T, C, H, W) with # dims = 5
|
| 336 |
+
"""
|
| 337 |
+
class TemporalAttention_Pos(nn.Module):
    """Per-pixel softmax attention along time with an optional additive
    relative-position bias on the logits; input/output (B, T, C, H, W)."""
    def __init__(self, dim, heads=4, dim_head=32):
        super(TemporalAttention_Pos, self).__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head*heads # No of Channel for (Q, K, V)
        self.to_qkv = nn.Conv2d(dim, hidden_dim*3, kernel_size=1, padding=0)
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, kernel_size=1),
        )

    def forward(self, x, rel_pos=None):
        """rel_pos: additive bias broadcastable against the (heads, T, T) logits,
        e.g. from RelativePositionBias."""
        assert x.ndim == 5
        B, T, C, H, W = x.shape
        x = x.reshape(B*T, C, H, W)
        qkv = self.to_qkv(x).chunk(3, dim=1) # qkv tuple in (q, k , v)
        # One attention problem per spatial location: ((b x y), heads, dim_head, T).
        q, k, v = map(lambda t: rearrange(t, '(b t) (h c) x y -> (b x y) h c t', h=self.heads, x=H, y=W, t=T), qkv)
        q = q*self.scale

        sim = torch.einsum('b h d i, b h d j -> b h i j', q, k)
        if rel_pos is not None:
            sim += rel_pos
        attn = sim.softmax(dim = -1)
        out = torch.einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, '(b x y) h t c -> (b t) (h c) x y', h=self.heads, x=H, y=W, t=T)
        out = self.to_out(out)
        return out.reshape(B, T, C, H, W)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
class TemporalAttention(nn.Module):
    """Per-pixel multi-head softmax attention over time, using Linear projections
    (rather than 1x1 convs); input/output (B, T, C, H, W)."""
    def __init__(self, dim, heads=4, dim_head=32):
        super(TemporalAttention, self).__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head*heads
        self.to_k = nn.Linear(dim, hidden_dim, bias=False)
        self.to_q = nn.Linear(dim, hidden_dim, bias=False)
        self.to_v = nn.Linear(dim, hidden_dim, bias=False)
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        assert x.ndim == 5
        B, T, C, H, W = x.shape
        # Treat each spatial location as an independent sequence of T tokens.
        x = rearrange(x, 'b t c h w -> b (h w) t c')

        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
        q = rearrange(q, '... n (h d) -> ... h n d', h=self.heads) # B (H W) Head T Dim
        k = rearrange(k, '... n (h d) -> ... h n d', h=self.heads)
        v = rearrange(v, '... n (h d) -> ... h n d', h=self.heads)
        q = q*self.scale

        sim = torch.einsum('... h i d, ... h j d -> ... h i j', q, k)
        attn = sim.softmax(dim=-1)
        out = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v)
        out = rearrange(out, '... h i d -> ... i (h d)', h=self.heads)
        out = self.to_out(out)
        out = rearrange(out, 'b (h w) t c -> b t c h w', h=H, w=W)
        return out
|
utilspp.py
ADDED
|
@@ -0,0 +1,566 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import lpips as lp
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import torchmetrics
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
from bisect import bisect_right
|
| 9 |
+
import torchvision.transforms as T
|
| 10 |
+
from torch import nn
|
| 11 |
+
|
| 12 |
+
from matplotlib.colors import ListedColormap, BoundaryNorm
|
| 13 |
+
from matplotlib.lines import Line2D
|
| 14 |
+
|
| 15 |
+
from data import dutils
|
| 16 |
+
|
| 17 |
+
# =======================================================================
|
| 18 |
+
# Scheduler Helper Function
|
| 19 |
+
# =======================================================================
|
| 20 |
+
|
| 21 |
+
class SequentialLR(torch.optim.lr_scheduler._LRScheduler):
    """Receives the list of schedulers that is expected to be called sequentially during
    optimization process and milestone points that provides exact intervals to reflect
    which scheduler is supposed to be called at a given epoch.

    Adapted from torch.optim.lr_scheduler.SequentialLR with one extension:
    ``step`` accepts a metric (``ref``) that is forwarded when the active
    scheduler is a ``ReduceLROnPlateau``.

    Args:
        schedulers (list): List of chained schedulers.
        milestones (list): List of integers that reflects milestone points.

    Example:
        >>> # Assuming optimizer uses lr = 1. for all groups
        >>> # lr = 0.1 if epoch == 0
        >>> # lr = 0.1 if epoch == 1
        >>> # lr = 0.9 if epoch == 2
        >>> # lr = 0.81 if epoch == 3
        >>> # lr = 0.729 if epoch == 4
        >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
        >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
        >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False):
        # All chained schedulers must drive the same optimizer instance.
        for scheduler_idx in range(1, len(schedulers)):
            if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
                raise ValueError(
                    "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
                    "got schedulers at index {} and {} to be different".format(0, scheduler_idx)
                )
        if (len(milestones) != len(schedulers) - 1):
            raise ValueError(
                "Sequential Schedulers expects number of schedulers provided to be one more "
                "than the number of milestone points, but got number of schedulers {} and the "
                "number of milestones to be equal to {}".format(len(schedulers), len(milestones))
            )
        self.optimizer = optimizer
        self._schedulers = schedulers
        self._milestones = milestones
        # _LRScheduler convention: last_epoch=-1 means "not stepped yet".
        self.last_epoch = last_epoch + 1

    def step(self, ref=None):
        """Advance one epoch and step the scheduler active at the new epoch.

        Args:
            ref: monitored metric forwarded to ``ReduceLROnPlateau.step``
                when that is the active scheduler; ignored otherwise.
        """
        self.last_epoch += 1
        # Index of the scheduler responsible for the current epoch.
        idx = bisect_right(self._milestones, self.last_epoch)
        if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
            # A milestone was just crossed: restart the newly-active
            # scheduler at its own epoch 0.
            self._schedulers[idx].step(0)
        else:
            # Check HERE
            # ReduceLROnPlateau.step expects the monitored metric, not an epoch.
            if isinstance(self._schedulers[idx], torch.optim.lr_scheduler.ReduceLROnPlateau):
                self._schedulers[idx].step(ref)
            else:
                self._schedulers[idx].step()

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The wrapped scheduler states will also be saved.
        """
        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
        state_dict['_schedulers'] = [None] * len(self._schedulers)

        for idx, s in enumerate(self._schedulers):
            state_dict['_schedulers'][idx] = s.state_dict()

        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        _schedulers = state_dict.pop('_schedulers')
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict['_schedulers'] = _schedulers

        for idx, s in enumerate(_schedulers):
            self._schedulers[idx].load_state_dict(s)
| 106 |
+
|
| 107 |
+
def warmup_lambda(warmup_steps, min_lr_ratio=0.1):
    """Build a multiplier function for ``torch.optim.lr_scheduler.LambdaLR``.

    The returned callable ramps the learning-rate factor linearly from
    ``min_lr_ratio`` (at epoch 0) to 1.0 (at epoch ``warmup_steps``) and
    stays at 1.0 afterwards.

    Args:
        warmup_steps: number of warmup epochs/steps. Values <= 0 disable
            warmup (the factor is always 1.0); the original code divided by
            zero in that case.
        min_lr_ratio: starting factor at epoch 0.

    Returns:
        Callable mapping an epoch index to a learning-rate multiplier.
    """
    def ret_lambda(epoch):
        # Guard against warmup_steps <= 0 (would otherwise divide by zero).
        if warmup_steps <= 0 or epoch > warmup_steps:
            return 1.0
        return min_lr_ratio + (1.0 - min_lr_ratio) * epoch / warmup_steps
    return ret_lambda
|
| 114 |
+
|
| 115 |
+
# =======================================================================
|
| 116 |
+
# Utils in utils :)
|
| 117 |
+
# =======================================================================
|
| 118 |
+
def to_cpu_tensor(*args):
    '''
    Input arbitrary number of array/tensors, each will be converted to CPU torch.Tensor.
    Items of any other type are passed through unchanged; a single input
    yields a single output (not a one-element list).
    '''
    def _as_cpu(item):
        # numpy arrays are first wrapped as float torch tensors.
        if type(item) is np.ndarray:
            item = torch.Tensor(item)
        # Tensors (including freshly wrapped ones) are moved to the CPU.
        if type(item) is torch.Tensor:
            item = item.cpu()
        return item

    converted = [_as_cpu(item) for item in args]
    # single value input: return single value output
    return converted[0] if len(converted) == 1 else converted
|
| 133 |
+
|
| 134 |
+
def merge_leading_dims(tensor, n=2):
    '''
    Collapse the first ``n`` dimensions of ``tensor`` into a single one.
    E.g. (B, T, C, H, W) with n=2 becomes (B*T, C, H, W).
    '''
    trailing_shape = tensor.shape[n:]
    return tensor.reshape((-1, *trailing_shape))
|
| 139 |
+
|
| 140 |
+
# =======================================================================
|
| 141 |
+
# Model Preparation, saving & loading (copied from utils.py)
|
| 142 |
+
# =======================================================================
|
| 143 |
+
def build_model_name(model_type, model_config):
    '''
    Build the model name (without extension).

    Each config entry contributes ``key-value``; list/tuple values are
    joined with dashes, and boolean values contribute the key alone (True)
    or an empty item (inside sequences). Entries are joined with underscores.
    '''
    parts = [model_type]
    for key, value in model_config.items():
        if type(value) is list or type(value) is tuple:
            # Booleans inside sequences render as empty strings (legacy format).
            joined = '-'.join(
                str(item) if type(item) is not bool else '' for item in value
            )
            parts.append(key + '-' + joined)
        else:
            suffix = '' if type(value) is bool else '-' + str(value)
            parts.append(key + suffix)
    return '_'.join(parts)
|
| 158 |
+
|
| 159 |
+
def build_model_path(base_dir, dataset_type, model_type, timestamp=None):
    """Compose the directory path ``base_dir/dataset_type/model_type[/timestamp]``.

    Args:
        base_dir: root checkpoint directory.
        dataset_type: dataset sub-directory name.
        model_type: model sub-directory name.
        timestamp: ``None`` → no timestamp level; ``True`` → a fresh
            ``%Y%m%d%H%M%S`` timestamp; any string → used verbatim.

    Returns:
        The joined path string.
    """
    if timestamp is None:
        return os.path.join(base_dir, dataset_type, model_type)
    # Identity check (was `== True`, PEP 8 E712): only the literal True
    # triggers timestamp generation, never truthy strings.
    if timestamp is True:
        return os.path.join(base_dir, dataset_type, model_type, pd.Timestamp.now().strftime('%Y%m%d%H%M%S'))
    return os.path.join(base_dir, dataset_type, model_type, timestamp)
|
| 165 |
+
|
| 166 |
+
# =======================================================================
|
| 167 |
+
# Preprocess Function for Loading HKO-7 dataset
|
| 168 |
+
# =======================================================================
|
| 169 |
+
|
| 170 |
+
def hko7_preprocess(x_seq, x_mask, dt_clip, args):
    """Turn a raw HKO-7 batch into (input, target) tensor sequences.

    Args:
        x_seq: raw uint8 image batch laid out as (seq_length, batch, 1, H, W);
            transposed to (batch, seq_length, 1, H, W) and scaled to [0, 1].
        x_mask: unused here — kept for interface compatibility with the loader.
        dt_clip: passed through to the dutils scaling helpers.
        args: config object supporting ``in`` checks and attribute access;
            reads optional ``resize``, ``seq_len`` and ``scale`` fields.

    Returns:
        (x, y): first ``seq_len`` frames as input, the remainder as target,
        each of shape (batch, T, 1, resize, resize).
    """
    resize = args.resize if 'resize' in args else x_seq.shape[-1]
    seq_len = args.seq_len if 'seq_len' in args else 5

    # post-process on HKO-10
    x_seq = x_seq.transpose((1, 0, 2, 3, 4)) / 255. # => (batch_size, seq_length, 1, 480, 480)
    if 'scale' in args and args.scale == 'non-linear':
        x_seq = dutils.linear_to_nonlinear_batched(x_seq, dt_clip)
    else:
        # NOTE(review): the default (linear) branch applies
        # nonlinear_to_linear_batched — presumably the raw data is stored in
        # the non-linear scale; confirm against dutils.
        x_seq = dutils.nonlinear_to_linear_batched(x_seq, dt_clip)

    b, t, c, h, w = x_seq.shape
    assert c == 1, f'# channels ({c}) != 1'

    # resize (downsample) the images if necessary
    x_seq = torch.Tensor(x_seq).float().reshape((b*t, c, h, w))
    if resize != h:
        tform = T.Compose([
            T.ToPILImage(),
            T.Resize(resize),
            T.ToTensor(),
        ])
    else:
        # Identity transform keeps the per-frame loop below uniform.
        tform = T.Compose([])

    x_seq = torch.stack([tform(x_frame) for x_frame in x_seq], dim=0)
    x_seq = x_seq.reshape((b, t, c, resize, resize))

    # Split along time: first seq_len frames condition the model, the rest
    # are the prediction target.
    x, y = x_seq[:, :seq_len], x_seq[:, seq_len:]
    return x, y
|
| 200 |
+
|
| 201 |
+
# =======================================================================
|
| 202 |
+
# Evaluation Metrics-Related
|
| 203 |
+
# =======================================================================
|
| 204 |
+
|
| 205 |
+
def mae(*args):
    """Mean absolute error (L1), returned as a numpy scalar.

    Accepts the same positional arguments as ``torch.nn.functional.l1_loss``
    (input, target, ...). Replaces a lambda assigned to a name (PEP 8 E731)
    so tracebacks and ``repr`` show a proper function name.
    """
    return torch.nn.functional.l1_loss(*args).cpu().detach().numpy()


def mse(*args):
    """Mean squared error, returned as a numpy scalar.

    Accepts the same positional arguments as ``torch.nn.functional.mse_loss``.
    """
    return torch.nn.functional.mse_loss(*args).cpu().detach().numpy()
|
| 207 |
+
|
| 208 |
+
def ssim(y_pred, y):
    """Mean SSIM between two (B, T, C, H, W) batches, values clamped to [0, 1].

    Frames are flattened over the batch and time axes before scoring.
    """
    y, y_pred = to_cpu_tensor(y, y_pred)
    b, t, c, h, w = y.shape
    flatten = lambda s: s.reshape((b*t, c, h, w))
    # to further ensure any of the input is not negative
    target = torch.clamp(flatten(y), 0, 1)
    prediction = torch.clamp(flatten(y_pred), 0, 1)
    metric = torchmetrics.image.StructuralSimilarityIndexMeasure(data_range=1.0)
    return metric(prediction, target)
|
| 217 |
+
|
| 218 |
+
def psnr(y_pred, y):
    """Average per-frame PSNR between two (B, T, C, H, W) batches in [0, 1].

    Each of the B*T frames is scored individually and the scores are
    averaged, matching the original behavior.
    """
    y, y_pred = to_cpu_tensor(y, y_pred)
    b, t, c, h, w = y.shape
    y = y.reshape((b*t, c, h, w))
    y_pred = y_pred.reshape((b*t, c, h, w))
    # Construct the metric once instead of once per frame (the original
    # rebuilt the torchmetrics module inside the loop); each call still
    # returns the score of that frame only.
    metric = torchmetrics.image.PeakSignalNoiseRatio(data_range=1.0)
    acc_score = 0
    for i in range(b*t):
        acc_score += metric(y_pred[i], y[i]) / (b*t)
    return acc_score
|
| 227 |
+
|
| 228 |
+
# Lazily-created LPIPS network, shared across calls to avoid reloading weights.
GLOBAL_LPIPS_OBJ = None # a static variable
def lpips64(y_pred, y, net='vgg'):
    """LPIPS perceptual distance at 64x64 resolution, averaged over frames.

    Inputs are (B, T, C, H, W) batches in [0, 1]; they are flattened over
    batch/time, resized to 64x64, and rescaled to [-1, 1] as LPIPS expects.

    Note: ``net`` only takes effect on the first call — the backbone is
    cached in GLOBAL_LPIPS_OBJ and reused afterwards regardless of ``net``.
    """
    # convert the image range into [-1, 1], assuming the input range to be [0, 1]
    y = merge_leading_dims(y)
    y_pred = merge_leading_dims(y_pred)

    # Bicubic resize can overshoot slightly, hence the clamp back to [0, 1].
    y = torch.nn.functional.interpolate(y, (64, 64), mode='bicubic').clamp(0,1)
    y_pred = torch.nn.functional.interpolate(y_pred, (64, 64), mode='bicubic').clamp(0,1)

    y = (2 * y - 1)
    y_pred = (2 * y_pred - 1)
    global GLOBAL_LPIPS_OBJ
    if GLOBAL_LPIPS_OBJ is None:
        GLOBAL_LPIPS_OBJ = lp.LPIPS(net=net).to(y.device)
    return GLOBAL_LPIPS_OBJ(y_pred, y).mean()
|
| 243 |
+
|
| 244 |
+
def tfpn(y_pred, y, threshold, radius=1):
    '''
    Binary confusion-matrix counts (tp, tn, fp, fn) at a given threshold.

    Inputs are (B, T, C, H, W); the leading two dims are merged before
    scoring. With radius > 1 both maps are max-pooled first, so a hit within
    the pooling window counts as a detection.
    '''
    y = merge_leading_dims(y)
    y_pred = merge_leading_dims(y_pred)
    with torch.no_grad():
        if radius > 1:
            pool = nn.MaxPool2d(radius)
            y = pool(y)
            y_pred = pool(y_pred)
        # Binarize both maps at the threshold.
        y = torch.where(y >= threshold, 1, 0)
        y_pred = torch.where(y_pred >= threshold, 1, 0)
        # NOTE(review): inputs are already 0/1 here, so the `threshold`
        # argument to confusion_matrix looks redundant — confirm it is a
        # no-op for integer predictions in this torchmetrics version.
        mat = torchmetrics.functional.confusion_matrix(y_pred, y, task='binary', threshold=threshold)
        # torchmetrics layout: [[tn, fp], [fn, tp]].
        (tn, fp), (fn, tp) = to_cpu_tensor(mat)
    return tp, tn, fp, fn
|
| 260 |
+
|
| 261 |
+
def tfpn_pool(y_pred, y, threshold, radius):
    '''
    Pooled confusion-matrix counts (tp, tn, fp, fn).

    Unlike `tfpn`, binarization happens BEFORE max-pooling, and the pool uses
    an overlapping stride (radius // 4 when positive, else radius), so nearby
    detections tolerate spatial displacement.
    '''
    y_pred = merge_leading_dims(y_pred)
    y = merge_leading_dims(y)
    # Overlapping pooling windows when radius >= 4; disjoint otherwise.
    pool = nn.MaxPool2d(radius, stride=radius//4 if radius//4 > 0 else radius)
    with torch.no_grad():
        y = torch.where(y>=threshold, 1, 0).float()
        y_pred = torch.where(y_pred>=threshold, 1, 0).float()
        y = pool(y)
        y_pred = pool(y_pred)
        # NOTE(review): pooled maps are already 0/1, so `threshold` here is
        # presumably inert — verify against torchmetrics semantics.
        mat = torchmetrics.functional.confusion_matrix(y_pred, y, task='binary', threshold=threshold)
        # torchmetrics layout: [[tn, fp], [fn, tp]].
        (tn, fp), (fn, tp) = to_cpu_tensor(mat)
    return tp, tn, fp, fn
|
| 273 |
+
|
| 274 |
+
def csi(tp, tn, fp, fn):
    '''Critical Success Index: tp / (tp + fn + fp). The larger the better; 0 when the denominator vanishes.'''
    denominator = tp + fn + fp
    if denominator < 1e-7:
        return 0.
    return tp / denominator
|
| 279 |
+
|
| 280 |
+
def hss(tp, tn, fp, fn):
    '''Heidke Skill Score. Range (-inf, 1]; larger is better, 0 when the denominator vanishes.'''
    denominator = (tp+fn)*(fn+tn) + (tp+fp)*(fp+tn)
    if denominator == 0:
        return 0.
    numerator = 2 * (tp*tn - fp*fn)
    return numerator / denominator
|
| 285 |
+
|
| 286 |
+
# =======================================================================
|
| 287 |
+
# Data Visualization
|
| 288 |
+
# =======================================================================
|
| 289 |
+
|
| 290 |
+
def torch_visualize(sequences, savedir=None, horizontal=10, vmin=0, vmax=1):
    '''
    Render one or more frame sequences as grayscale image grids.

    input: sequences, a list/dict of numpy/torch arrays with shape (B, T, C, H, W)
           C is assumed to be 1 and squeezed
           If batch > 1, only the first sequence will be printed

    savedir: path to save the figure as a PNG (".png" appended if absent);
             when falsy the figure is shown interactively instead.
    horizontal: frames per row; vmin/vmax: grayscale display range.
    '''
    # First pass: compute the vertical height and convert to proper format
    vertical = 0
    display_texts = []
    if (type(sequences) is dict):
        temp = []
        for k, v in sequences.items():
            vertical += int(np.ceil(v.shape[1] / horizontal))
            temp.append(v)
            display_texts.append(k)
        sequences = temp
    else:
        for i, sequence in enumerate(sequences):
            vertical += int(np.ceil(sequence.shape[1] / horizontal))
            display_texts.append(f'Item {i+1}')
    sequences = to_cpu_tensor(*sequences)
    # Plot the sequences
    j = 0
    fig, axes = plt.subplots(vertical, horizontal, figsize=(2*horizontal, 2*vertical), tight_layout=True)
    plt.setp(axes, xticks=[], yticks=[])
    for k, sequence in enumerate(sequences):
        # only take the first batch, now seq[0] is the temporal dim
        sequence = sequence[0].squeeze() # (T, H, W)
        axes[j, 0].set_ylabel(display_texts[k])
        for i, frame in enumerate(sequence):
            # Wrap onto the next grid row every `horizontal` frames.
            j_shift = j + i // horizontal
            i_shift = i % horizontal
            axes[j_shift, i_shift].imshow(frame, vmin=vmin, vmax=vmax, cmap='gray')
        j += int(np.ceil(sequence.shape[0] / horizontal))
    if savedir:
        # BUGFIX: the original `savedir + '' if savedir.endswith('.png') else '.png'`
        # parsed as `(savedir + '') if ... else '.png'`, so any savedir without
        # a .png suffix was saved to the literal file ".png".
        plt.savefig(savedir if savedir.endswith('.png') else savedir + '.png')
        plt.close()
    else:
        plt.show()
|
| 329 |
+
|
| 330 |
+
""" Visualize function with colorbar and a line seprate input and output """
|
| 331 |
+
def color_visualize(sequences, savedir='', horizontal=5, skip=1, ypos=0):
    '''
    Render frame sequences with the VIL color map, a colorbar, and a dashed
    line separating the input rows from the output rows.

    input: sequences, a list/dict of numpy/torch arrays with shape (B, T, C, H, W)
           C is assumed to be 1 and squeezed
           If batch > 1, only the first sequence will be printed

    savedir: output path (saved when non-empty, otherwise shown);
    horizontal: frames per row; skip: time stride used in output labels;
    ypos: vertical figure coordinate of the separator line (0 = auto).
    '''
    plt.style.use(['science', 'no-latex'])
    # VIL (vertically integrated liquid) radar palette and level boundaries.
    VIL_COLORS = [[0, 0, 0],
                  [0.30196078431372547, 0.30196078431372547, 0.30196078431372547],
                  [0.1568627450980392, 0.7450980392156863, 0.1568627450980392],
                  [0.09803921568627451, 0.5882352941176471, 0.09803921568627451],
                  [0.0392156862745098, 0.4117647058823529, 0.0392156862745098],
                  [0.0392156862745098, 0.29411764705882354, 0.0392156862745098],
                  [0.9607843137254902, 0.9607843137254902, 0.0],
                  [0.9294117647058824, 0.6745098039215687, 0.0],
                  [0.9411764705882353, 0.43137254901960786, 0.0],
                  [0.6274509803921569, 0.0, 0.0],
                  [0.9058823529411765, 0.0, 1.0]]

    VIL_LEVELS = [0.0, 16.0, 31.0, 59.0, 74.0, 100.0, 133.0, 160.0, 181.0, 219.0, 255.0]

    # First pass: compute the vertical height and convert to proper format
    vertical = 0
    display_texts = []
    if (type(sequences) is dict):
        temp = []
        for k, v in sequences.items():
            vertical += int(np.ceil(v.shape[1] / horizontal))
            temp.append(v)
            display_texts.append(k)
        sequences = temp
    else:
        for i, sequence in enumerate(sequences):
            vertical += int(np.ceil(sequence.shape[1] / horizontal))
            display_texts.append(f'Item {i+1}')
    sequences = to_cpu_tensor(*sequences)
    # Plot the sequences
    j = 0
    fig, axes = plt.subplots(vertical, horizontal, figsize=(2*horizontal, 2*vertical), tight_layout=True)
    plt.subplots_adjust(hspace=0.0, wspace=0.0) # tight layout
    plt.setp(axes, xticks=[], yticks=[])
    for k, sequence in enumerate(sequences):
        # only take the first batch, now seq[0] is the temporal dim
        sequence = sequence[0].squeeze() # (T, H, W)

        ## =================
        # = labels of time =
        # First row: observed frames labelled t-..t-0; last row: forecast
        # frames labelled t+1, t+1+skip, ...
        if k == 0:
            for i in range(len(sequence)):
                axes[j, i].set_xlabel(f'$t-{(len(sequence)-i)-1}$', fontsize=16)
                axes[j, i].xaxis.set_label_position('top')
        elif k == len(sequences)-1:
            for i in range(len(sequence)):
                axes[j, i].set_xlabel(f'$t+{skip*i+1}$', fontsize=16)
                axes[j, i].xaxis.set_label_position('bottom')
        ## =================
        axes[j, 0].set_ylabel(display_texts[k], fontsize=16)
        for i, frame in enumerate(sequence):
            j_shift = j + i // horizontal
            i_shift = i % horizontal
            # Frames are in [0, 1]; scale to [0, 255] to match VIL_LEVELS.
            im = axes[j_shift, i_shift].imshow(frame*255, cmap=ListedColormap(VIL_COLORS), \
                    norm=BoundaryNorm(VIL_LEVELS, ListedColormap(VIL_COLORS).N))
        j += int(np.ceil(sequence.shape[0] / horizontal))

    ## = plot splitting line (between input row and output rows) =
    if ypos == 0:
        ypos = 1 - 1 / len(sequences) - 0.017
    fig.lines.append(Line2D((0, 1), (ypos, ypos), transform=fig.transFigure, ls='--', linewidth=2, color='#444'))
    # color bar
    cax = fig.add_axes([1, 0.05, 0.02, 0.5])
    fig.colorbar(im, cax=cax)
    ## =================
    if savedir:
        # NOTE(review): `savedir + '' if len(savedir)>0 else 'out.png'`
        # parses as `(savedir + '') if ... else 'out.png'`; inside this
        # branch savedir is always truthy, so this is effectively
        # plt.savefig(savedir) — the 'out.png' fallback is dead code.
        plt.savefig(savedir + '' if len(savedir)>0 else 'out.png')
        plt.close()
    else:
        plt.show()
|
| 408 |
+
|
| 409 |
+
from tempfile import NamedTemporaryFile
|
| 410 |
+
|
| 411 |
+
""" Visualize function with colorbar and a line seprate input and output """
|
| 412 |
+
def gradio_visualize(sequences, horizontal=5, skip=1, ypos=0):
    '''
    Like `color_visualize`, but saves the figure to a temporary PNG and
    returns its path (for display in a Gradio app).

    input: sequences, a list/dict of numpy/torch arrays with shape (B, T, C, H, W)
           C is assumed to be 1 and squeezed
           If batch > 1, only the first sequence will be printed

    Returns: filesystem path of the rendered PNG (caller is responsible for
    cleanup — the temp file is created with delete=False).
    '''
    plt.style.use(['science', 'no-latex'])
    # VIL (vertically integrated liquid) radar palette and level boundaries.
    VIL_COLORS = [[0, 0, 0],
                  [0.30196078431372547, 0.30196078431372547, 0.30196078431372547],
                  [0.1568627450980392, 0.7450980392156863, 0.1568627450980392],
                  [0.09803921568627451, 0.5882352941176471, 0.09803921568627451],
                  [0.0392156862745098, 0.4117647058823529, 0.0392156862745098],
                  [0.0392156862745098, 0.29411764705882354, 0.0392156862745098],
                  [0.9607843137254902, 0.9607843137254902, 0.0],
                  [0.9294117647058824, 0.6745098039215687, 0.0],
                  [0.9411764705882353, 0.43137254901960786, 0.0],
                  [0.6274509803921569, 0.0, 0.0],
                  [0.9058823529411765, 0.0, 1.0]]

    VIL_LEVELS = [0.0, 16.0, 31.0, 59.0, 74.0, 100.0, 133.0, 160.0, 181.0, 219.0, 255.0]

    # First pass: compute the vertical height and convert to proper format
    vertical = 0
    display_texts = []
    if (type(sequences) is dict):
        temp = []
        for k, v in sequences.items():
            vertical += int(np.ceil(v.shape[1] / horizontal))
            temp.append(v)
            display_texts.append(k)
        sequences = temp
    else:
        for i, sequence in enumerate(sequences):
            vertical += int(np.ceil(sequence.shape[1] / horizontal))
            display_texts.append(f'Item {i+1}')
    sequences = to_cpu_tensor(*sequences)
    # Plot the sequences
    j = 0
    fig, axes = plt.subplots(vertical, horizontal, figsize=(2*horizontal, 2*vertical), tight_layout=True)
    plt.subplots_adjust(hspace=0.0, wspace=0.0) # tight layout
    plt.setp(axes, xticks=[], yticks=[])
    for k, sequence in enumerate(sequences):
        # NOTE(review): unlike color_visualize, this squeezes the whole
        # array instead of indexing batch 0 — presumably callers pass
        # batch-size-1 sequences; confirm against app.py.
        sequence = sequence.squeeze() # (T, H, W)

        ## =================
        # = labels of time =
        # First row: observed frames; last row: forecast frames.
        if k == 0:
            for i in range(len(sequence)):
                axes[j, i].set_xlabel(f'$t-{(len(sequence)-i)-1}$', fontsize=16)
                axes[j, i].xaxis.set_label_position('top')
        elif k == len(sequences)-1:
            for i in range(len(sequence)):
                axes[j, i].set_xlabel(f'$t+{skip*i+1}$', fontsize=16)
                axes[j, i].xaxis.set_label_position('bottom')
        ## =================
        axes[j, 0].set_ylabel(display_texts[k], fontsize=16)
        for i, frame in enumerate(sequence):
            j_shift = j + i // horizontal
            i_shift = i % horizontal
            # Frames are in [0, 1]; scale to [0, 255] to match VIL_LEVELS.
            im = axes[j_shift, i_shift].imshow(frame*255, cmap=ListedColormap(VIL_COLORS), \
                    norm=BoundaryNorm(VIL_LEVELS, ListedColormap(VIL_COLORS).N))
        j += int(np.ceil(sequence.shape[0] / horizontal))

    ## = plot splitting line (between input row and output rows) =
    if ypos == 0:
        ypos = 1 - 1 / len(sequences) - 0.017
    fig.lines.append(Line2D((0, 1), (ypos, ypos), transform=fig.transFigure, ls='--', linewidth=2, color='#444'))
    # color bar
    cax = fig.add_axes([1, 0.05, 0.02, 0.5])
    fig.colorbar(im, cax=cax)

    # Save the figure to a temporary file
    with NamedTemporaryFile(suffix=".png", delete=False) as ff:
        fig.savefig(ff.name)
        file_path = ff.name

    # It's important to close the figure to prevent memory leaks
    plt.close(fig)

    return file_path
|
| 493 |
+
|
| 494 |
+
import matplotlib.animation as animation
|
| 495 |
+
|
| 496 |
+
def gradio_gif(sequences, T):
    '''
    Animate several forecast ensembles side by side and save the result as a
    temporary GIF, returning its path (for display in a Gradio app).

    input: sequences, a dict of numpy/torch arrays, each with leading time
           axis of length >= T; trailing dims squeeze to (H, W) per frame.
    T: number of animation frames.

    Returns: filesystem path of the GIF (created with delete=False; caller
    is responsible for cleanup).
    '''
    plt.style.use(['science', 'no-latex'])
    # VIL (vertically integrated liquid) radar palette and level boundaries.
    VIL_COLORS = [[0, 0, 0],
                  [0.30196078431372547, 0.30196078431372547, 0.30196078431372547],
                  [0.1568627450980392, 0.7450980392156863, 0.1568627450980392],
                  [0.09803921568627451, 0.5882352941176471, 0.09803921568627451],
                  [0.0392156862745098, 0.4117647058823529, 0.0392156862745098],
                  [0.0392156862745098, 0.29411764705882354, 0.0392156862745098],
                  [0.9607843137254902, 0.9607843137254902, 0.0],
                  [0.9294117647058824, 0.6745098039215687, 0.0],
                  [0.9411764705882353, 0.43137254901960786, 0.0],
                  [0.6274509803921569, 0.0, 0.0],
                  [0.9058823529411765, 0.0, 1.0]]

    VIL_LEVELS = [0.0, 16.0, 31.0, 59.0, 74.0, 100.0, 133.0, 160.0, 181.0, 219.0, 255.0]

    # One column per ensemble member.
    horizontal = len(sequences)
    fig_size = 3
    fig, axes = plt.subplots(nrows=1, ncols=horizontal, figsize=(fig_size*horizontal, fig_size), tight_layout=True)
    plt.subplots_adjust(hspace=0.0, wspace=0.0) # tight layout
    plt.setp(axes, xticks=[], yticks=[])
    # Draw frame 0 of every ensemble as the initial state.
    for i, sequence in enumerate(sequences.values()):
        axes[i].set_xticks([])
        axes[i].set_yticks([])
        axes[i].set_xlabel(f'Ensemble {i+1}', fontsize=12)
        frame = sequence[0].squeeze()
        im = axes[i].imshow(frame*255, cmap=ListedColormap(VIL_COLORS), \
                norm=BoundaryNorm(VIL_LEVELS, ListedColormap(VIL_COLORS).N), animated=True)

    title = fig.suptitle('', y=0.9, x=0.505, fontsize=16) # Initialize an empty super title

    # Colorbar keyed off the last drawn image (shared palette/levels).
    fig.colorbar(im)

    def animate(t):
        # Redraw every panel with the frame at time t.
        for i, sequence in enumerate(sequences.values()):
            frame = sequence[t].squeeze()
            im = axes[i].imshow(frame*255, cmap=ListedColormap(VIL_COLORS), \
                    norm=BoundaryNorm(VIL_LEVELS, ListedColormap(VIL_COLORS).N), animated=True)
            plt.subplots_adjust(hspace=0.0, wspace=0.0) # tight layout

        title.set_text(f'$t + {t}$') # update the title text

        return fig,

    ani = animation.FuncAnimation(fig, animate, frames=T, interval=750, blit=True, repeat_delay=50,)

    # Save the figure to a temporary file
    with NamedTemporaryFile(suffix=".gif", delete=False) as ff:
        ani.save(ff.name, writer='pillow', fps=5)
        file_path = ff.name

    plt.close()
    return file_path
|
| 554 |
+
|
| 555 |
+
# import matplotlib.pyplot as plt
|
| 556 |
+
# import matplotlib.animation as animation
|
| 557 |
+
# def make_gif(frames, save_path):
|
| 558 |
+
# fig, ax = plt.subplots(figsize=(4,4))
|
| 559 |
+
# im = ax.imshow(frames[0].squeeze(), cmap='gray', vmin=0, vmax=1, animated=True)
|
| 560 |
+
# ax.set_axis_off()
|
| 561 |
+
|
| 562 |
+
# def update(i):
|
| 563 |
+
# im.set_array(frames[i].squeeze())
|
| 564 |
+
# return im,
|
| 565 |
+
# animation_fig =
|
| 566 |
+
# animation_fig.save(f"./{save_path}.gif")
|