Add files using upload-large-folder tool
Browse files- Rectified_Noise/GVP-Disp/W_No.log +5 -0
- Rectified_Noise/GVP-Disp/W_True_0.15.log +5 -0
- Rectified_Noise/GVP-Disp/W_True_0.5.log +5 -0
- Rectified_Noise/GVP-Disp/download.py +41 -0
- Rectified_Noise/GVP-Disp/environment.yml +16 -0
- Rectified_Noise/GVP-Disp/models.py +647 -0
- Rectified_Noise/GVP-Disp/sample_ddp.py +233 -0
- Rectified_Noise/GVP-Disp/sample_rectified_noise.py +380 -0
- Rectified_Noise/GVP-Disp/train_utils.py +35 -0
- Rectified_Noise/GVP-Disp/w_training1_VP.log +628 -0
- Rectified_Noise/GVP-Disp/权重类型分析.md +133 -0
- Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000032.png +0 -0
- Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000077.png +0 -0
- Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000133.png +0 -0
- Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000161.png +0 -0
- Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000220.png +0 -0
- Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000331.png +0 -0
- Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000387.png +0 -0
- Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000505.png +0 -0
- Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000517.png +0 -0
- Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000551.png +0 -0
- Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000726.png +0 -0
- Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000817.png +0 -0
- Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000865.png +0 -0
- Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000914.png +0 -0
- Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000940.png +0 -0
- Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001043.png +0 -0
- Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001210.png +0 -0
- SiT_back/SiT_clean/W_training.log +110 -0
- SiT_back/SiT_clean/__pycache__/download.cpython-312.pyc +0 -0
- SiT_back/SiT_clean/__pycache__/models.cpython-312.pyc +0 -0
- SiT_back/SiT_clean/__pycache__/train_utils.cpython-312.pyc +0 -0
- SiT_back/SiT_clean/download.py +40 -0
- SiT_back/SiT_clean/models.py +370 -0
- SiT_back/SiT_clean/run.sh +0 -0
- SiT_back/SiT_clean/sample.py +144 -0
- SiT_back/SiT_clean/sample_ddp.py +233 -0
- SiT_back/SiT_clean/train.py +371 -0
- SiT_back/SiT_clean/train_utils.py +32 -0
- SiT_back/SiT_clean/transport/__init__.py +65 -0
- SiT_back/SiT_clean/transport/__pycache__/__init__.cpython-312.pyc +0 -0
- SiT_back/SiT_clean/transport/__pycache__/integrators.cpython-312.pyc +0 -0
- SiT_back/SiT_clean/transport/__pycache__/path.cpython-312.pyc +0 -0
- SiT_back/SiT_clean/transport/__pycache__/transport.cpython-312.pyc +0 -0
- SiT_back/SiT_clean/transport/__pycache__/utils.cpython-312.pyc +0 -0
- SiT_back/SiT_clean/transport/integrators.py +115 -0
- SiT_back/SiT_clean/transport/path.py +192 -0
- SiT_back/SiT_clean/transport/transport.py +440 -0
- SiT_back/SiT_clean/transport/utils.py +29 -0
- SiT_back/SiT_clean/wandb_utils.py +55 -0
Rectified_Noise/GVP-Disp/W_No.log
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 0 |
0%| | 0/47 [00:00<?, ?it/s]
|
| 1 |
2%|▏ | 1/47 [02:10<1:39:52, 130.26s/it]
|
| 2 |
4%|▍ | 2/47 [04:19<1:37:18, 129.75s/it]
|
| 3 |
6%|▋ | 3/47 [06:29<1:35:03, 129.63s/it]
|
| 4 |
9%|▊ | 4/47 [08:38<1:32:51, 129.58s/it]
|
| 5 |
11%|█ | 5/47 [10:48<1:30:41, 129.55s/it]
|
| 6 |
13%|█▎ | 6/47 [12:57<1:28:31, 129.54s/it]
|
| 7 |
15%|█▍ | 7/47 [15:07<1:26:22, 129.56s/it]
|
| 8 |
17%|█▋ | 8/47 [17:14<1:23:44, 128.82s/it]
|
| 9 |
19%|█▉ | 9/47 [19:20<1:21:01, 127.93s/it]
|
| 10 |
21%|██▏ | 10/47 [21:29<1:19:01, 128.15s/it]
|
| 11 |
23%|██▎ | 11/47 [23:38<1:17:05, 128.47s/it]
|
| 12 |
26%|██▌ | 12/47 [25:47<1:15:03, 128.67s/it]
|
| 13 |
28%|██▊ | 13/47 [27:56<1:13:02, 128.91s/it]
|
|
|
|
| 1 |
+
[NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
|
| 2 |
+
Starting rank=0, seed=0, world_size=1.
|
| 3 |
+
Saving .png samples at GVP_samples/depth-mu-2-threshold-0.0-0025000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04
|
| 4 |
+
Total number of images that will be sampled: 3008
|
| 5 |
+
|
| 6 |
0%| | 0/47 [00:00<?, ?it/s]
|
| 7 |
2%|▏ | 1/47 [02:10<1:39:52, 130.26s/it]
|
| 8 |
4%|▍ | 2/47 [04:19<1:37:18, 129.75s/it]
|
| 9 |
6%|▋ | 3/47 [06:29<1:35:03, 129.63s/it]
|
| 10 |
9%|▊ | 4/47 [08:38<1:32:51, 129.58s/it]
|
| 11 |
11%|█ | 5/47 [10:48<1:30:41, 129.55s/it]
|
| 12 |
13%|█▎ | 6/47 [12:57<1:28:31, 129.54s/it]
|
| 13 |
15%|█▍ | 7/47 [15:07<1:26:22, 129.56s/it]
|
| 14 |
17%|█▋ | 8/47 [17:14<1:23:44, 128.82s/it]
|
| 15 |
19%|█▉ | 9/47 [19:20<1:21:01, 127.93s/it]
|
| 16 |
21%|██▏ | 10/47 [21:29<1:19:01, 128.15s/it]
|
| 17 |
23%|██▎ | 11/47 [23:38<1:17:05, 128.47s/it]
|
| 18 |
26%|██▌ | 12/47 [25:47<1:15:03, 128.67s/it]
|
| 19 |
28%|██▊ | 13/47 [27:56<1:13:02, 128.91s/it]
|
Rectified_Noise/GVP-Disp/W_True_0.15.log
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 0 |
0%| | 0/47 [00:00<?, ?it/s]
|
| 1 |
2%|▏ | 1/47 [01:34<1:12:45, 94.90s/it]
|
| 2 |
4%|▍ | 2/47 [03:08<1:10:43, 94.31s/it]
|
| 3 |
6%|▋ | 3/47 [04:42<1:09:03, 94.17s/it]
|
| 4 |
9%|▊ | 4/47 [06:16<1:07:23, 94.04s/it]
|
| 5 |
11%|█ | 5/47 [07:50<1:05:47, 93.99s/it]
|
| 6 |
13%|█▎ | 6/47 [09:24<1:04:09, 93.88s/it]
|
| 7 |
15%|█▍ | 7/47 [10:58<1:02:34, 93.85s/it]
|
| 8 |
17%|█▋ | 8/47 [12:31<1:00:57, 93.79s/it]
|
| 9 |
19%|█▉ | 9/47 [14:05<59:23, 93.77s/it]
|
| 10 |
21%|██▏ | 10/47 [15:39<57:48, 93.75s/it]
|
| 11 |
23%|██▎ | 11/47 [17:12<56:15, 93.76s/it]
|
| 12 |
26%|██▌ | 12/47 [18:46<54:41, 93.76s/it]
|
| 13 |
28%|██▊ | 13/47 [20:20<53:08, 93.78s/it]
|
| 14 |
30%|██▉ | 14/47 [21:54<51:36, 93.82s/it]
|
| 15 |
32%|███▏ | 15/47 [23:28<50:01, 93.80s/it]
|
| 16 |
34%|███▍ | 16/47 [25:01<48:25, 93.71s/it]
|
| 17 |
36%|███▌ | 17/47 [26:35<46:53, 93.77s/it]
|
| 18 |
38%|███▊ | 18/47 [28:09<45:19, 93.78s/it]
|
| 19 |
40%|████ | 19/47 [29:43<43:45, 93.77s/it]
|
|
|
|
| 1 |
+
[NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
|
| 2 |
+
Starting rank=0, seed=0, world_size=1.
|
| 3 |
+
Saving .png samples at GVP_samples/depth-mu-2-threshold-0.15-0025000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04
|
| 4 |
+
Total number of images that will be sampled: 3008
|
| 5 |
+
|
| 6 |
0%| | 0/47 [00:00<?, ?it/s]
|
| 7 |
2%|▏ | 1/47 [01:34<1:12:45, 94.90s/it]
|
| 8 |
4%|▍ | 2/47 [03:08<1:10:43, 94.31s/it]
|
| 9 |
6%|▋ | 3/47 [04:42<1:09:03, 94.17s/it]
|
| 10 |
9%|▊ | 4/47 [06:16<1:07:23, 94.04s/it]
|
| 11 |
11%|█ | 5/47 [07:50<1:05:47, 93.99s/it]
|
| 12 |
13%|█▎ | 6/47 [09:24<1:04:09, 93.88s/it]
|
| 13 |
15%|█▍ | 7/47 [10:58<1:02:34, 93.85s/it]
|
| 14 |
17%|█▋ | 8/47 [12:31<1:00:57, 93.79s/it]
|
| 15 |
19%|█▉ | 9/47 [14:05<59:23, 93.77s/it]
|
| 16 |
21%|██▏ | 10/47 [15:39<57:48, 93.75s/it]
|
| 17 |
23%|██▎ | 11/47 [17:12<56:15, 93.76s/it]
|
| 18 |
26%|██▌ | 12/47 [18:46<54:41, 93.76s/it]
|
| 19 |
28%|██▊ | 13/47 [20:20<53:08, 93.78s/it]
|
| 20 |
30%|██▉ | 14/47 [21:54<51:36, 93.82s/it]
|
| 21 |
32%|███▏ | 15/47 [23:28<50:01, 93.80s/it]
|
| 22 |
34%|███▍ | 16/47 [25:01<48:25, 93.71s/it]
|
| 23 |
36%|███▌ | 17/47 [26:35<46:53, 93.77s/it]
|
| 24 |
38%|███▊ | 18/47 [28:09<45:19, 93.78s/it]
|
| 25 |
40%|████ | 19/47 [29:43<43:45, 93.77s/it]
|
Rectified_Noise/GVP-Disp/W_True_0.5.log
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 0 |
0%| | 0/47 [00:00<?, ?it/s]
|
| 1 |
2%|▏ | 1/47 [01:34<1:12:36, 94.71s/it]
|
| 2 |
4%|▍ | 2/47 [03:08<1:10:34, 94.11s/it]
|
| 3 |
6%|▋ | 3/47 [04:42<1:08:54, 93.97s/it]
|
| 4 |
9%|▊ | 4/47 [06:15<1:07:14, 93.82s/it]
|
| 5 |
11%|█ | 5/47 [07:49<1:05:39, 93.79s/it]
|
| 6 |
13%|█▎ | 6/47 [09:23<1:04:03, 93.75s/it]
|
| 7 |
15%|█▍ | 7/47 [10:57<1:02:31, 93.79s/it]
|
| 8 |
17%|█▋ | 8/47 [12:30<1:00:56, 93.75s/it]
|
| 9 |
19%|█▉ | 9/47 [14:04<59:21, 93.73s/it]
|
| 10 |
21%|██▏ | 10/47 [15:38<57:47, 93.73s/it]
|
| 11 |
23%|██▎ | 11/47 [17:11<56:14, 93.74s/it]
|
| 12 |
26%|██▌ | 12/47 [18:45<54:40, 93.74s/it]
|
| 13 |
28%|██▊ | 13/47 [20:19<53:06, 93.72s/it]
|
| 14 |
30%|██▉ | 14/47 [21:53<51:34, 93.77s/it]
|
| 15 |
32%|███▏ | 15/47 [23:26<50:00, 93.77s/it]
|
| 16 |
34%|███▍ | 16/47 [25:00<48:24, 93.70s/it]
|
| 17 |
36%|███▌ | 17/47 [26:34<46:53, 93.77s/it]
|
| 18 |
38%|███▊ | 18/47 [28:08<45:19, 93.79s/it]
|
| 19 |
40%|████ | 19/47 [29:42<43:45, 93.78s/it]
|
|
|
|
| 1 |
+
[NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
|
| 2 |
+
Starting rank=0, seed=0, world_size=1.
|
| 3 |
+
Saving .png samples at GVP_samples/depth-mu-2-threshold-0.5-0025000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04
|
| 4 |
+
Total number of images that will be sampled: 3008
|
| 5 |
+
|
| 6 |
0%| | 0/47 [00:00<?, ?it/s]
|
| 7 |
2%|▏ | 1/47 [01:34<1:12:36, 94.71s/it]
|
| 8 |
4%|▍ | 2/47 [03:08<1:10:34, 94.11s/it]
|
| 9 |
6%|▋ | 3/47 [04:42<1:08:54, 93.97s/it]
|
| 10 |
9%|▊ | 4/47 [06:15<1:07:14, 93.82s/it]
|
| 11 |
11%|█ | 5/47 [07:49<1:05:39, 93.79s/it]
|
| 12 |
13%|█▎ | 6/47 [09:23<1:04:03, 93.75s/it]
|
| 13 |
15%|█▍ | 7/47 [10:57<1:02:31, 93.79s/it]
|
| 14 |
17%|█▋ | 8/47 [12:30<1:00:56, 93.75s/it]
|
| 15 |
19%|█▉ | 9/47 [14:04<59:21, 93.73s/it]
|
| 16 |
21%|██▏ | 10/47 [15:38<57:47, 93.73s/it]
|
| 17 |
23%|██▎ | 11/47 [17:11<56:14, 93.74s/it]
|
| 18 |
26%|██▌ | 12/47 [18:45<54:40, 93.74s/it]
|
| 19 |
28%|██▊ | 13/47 [20:19<53:06, 93.72s/it]
|
| 20 |
30%|██▉ | 14/47 [21:53<51:34, 93.77s/it]
|
| 21 |
32%|███▏ | 15/47 [23:26<50:00, 93.77s/it]
|
| 22 |
34%|███▍ | 16/47 [25:00<48:24, 93.70s/it]
|
| 23 |
36%|███▌ | 17/47 [26:34<46:53, 93.77s/it]
|
| 24 |
38%|███▊ | 18/47 [28:08<45:19, 93.79s/it]
|
| 25 |
40%|████ | 19/47 [29:42<43:45, 93.78s/it]
|
Rectified_Noise/GVP-Disp/download.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This source code is licensed under the license found in the
|
| 2 |
+
# LICENSE file in the root directory of this source tree.
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Functions for downloading pre-trained SiT models
|
| 6 |
+
"""
|
| 7 |
+
from torchvision.datasets.utils import download_url
|
| 8 |
+
import torch
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
pretrained_models = {'SiT-XL-2-256x256.pt'}
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def find_model(model_name):
|
| 16 |
+
"""
|
| 17 |
+
Finds a pre-trained SiT model, downloading it if necessary. Alternatively, loads a model from a local path.
|
| 18 |
+
"""
|
| 19 |
+
if model_name in pretrained_models:
|
| 20 |
+
return download_model(model_name)
|
| 21 |
+
else:
|
| 22 |
+
assert os.path.isfile(model_name), f'Could not find SiT checkpoint at {model_name}'
|
| 23 |
+
checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage, weights_only=False)
|
| 24 |
+
if "ema" in checkpoint: # supports checkpoints from train.py
|
| 25 |
+
checkpoint = checkpoint["ema"]
|
| 26 |
+
return checkpoint
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def download_model(model_name):
|
| 30 |
+
"""
|
| 31 |
+
Downloads a pre-trained SiT model from the web.
|
| 32 |
+
"""
|
| 33 |
+
assert model_name in pretrained_models
|
| 34 |
+
local_path = f'pretrained_models/{model_name}'
|
| 35 |
+
if not os.path.isfile(local_path):
|
| 36 |
+
os.makedirs('pretrained_models', exist_ok=True)
|
| 37 |
+
web_path = f'https://www.dl.dropboxusercontent.com/scl/fi/as9oeomcbub47de5g4be0/SiT-XL-2-256.pt?rlkey=uxzxmpicu46coq3msb17b9ofa&dl=0'
|
| 38 |
+
download_url(web_path, 'pretrained_models', filename=model_name)
|
| 39 |
+
model = torch.load(local_path, map_location=lambda storage, loc: storage, weights_only=False)
|
| 40 |
+
return model
|
| 41 |
+
|
Rectified_Noise/GVP-Disp/environment.yml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: RN
|
| 2 |
+
channels:
|
| 3 |
+
- pytorch
|
| 4 |
+
- nvidia
|
| 5 |
+
dependencies:
|
| 6 |
+
- python >= 3.8
|
| 7 |
+
- pytorch >= 1.13
|
| 8 |
+
- torchvision
|
| 9 |
+
- pytorch-cuda >=11.7
|
| 10 |
+
- pip
|
| 11 |
+
- pip:
|
| 12 |
+
- timm
|
| 13 |
+
- diffusers
|
| 14 |
+
- accelerate
|
| 15 |
+
- torchdiffeq
|
| 16 |
+
- wandb
|
Rectified_Noise/GVP-Disp/models.py
ADDED
|
@@ -0,0 +1,647 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This source code is licensed under the license found in the
|
| 2 |
+
# LICENSE file in the root directory of this source tree.
|
| 3 |
+
# --------------------------------------------------------
|
| 4 |
+
# References:
|
| 5 |
+
# GLIDE: https://github.com/openai/glide-text2im
|
| 6 |
+
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
|
| 7 |
+
# --------------------------------------------------------
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import numpy as np
|
| 12 |
+
import math
|
| 13 |
+
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
|
| 14 |
+
|
| 15 |
+
|
def modulate(x, shift, scale):
    """Per-sample affine modulation: x * (1 + scale) + shift, broadcast over the token axis."""
    scaled = x * (scale.unsqueeze(1) + 1)
    return scaled + shift.unsqueeze(1)
| 18 |
+
|
| 19 |
+
|
| 20 |
+
#################################################################################
|
| 21 |
+
# Embedding Layers for Timesteps and Class Labels #
|
| 22 |
+
#################################################################################
|
| 23 |
+
|
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations: a sinusoidal
    frequency encoding followed by a two-layer SiLU MLP.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element (may be fractional).
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Pad odd output dimensions with a zero column.
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
        return emb

    def forward(self, t):
        frequencies = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(frequencies)
| 62 |
+
|
| 63 |
+
|
class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label
    dropout for classifier-free guidance: dropped labels map to an extra
    "null" embedding row at index `num_classes`.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        # One extra row for the null (unconditional) token when dropout is enabled.
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is not None:
            drop_ids = force_drop_ids == 1
        else:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        return torch.where(drop_ids, self.num_classes, labels)

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        return self.embedding_table(labels)
| 92 |
+
|
| 93 |
+
|
| 94 |
+
#################################################################################
|
| 95 |
+
# Core SiT Model #
|
| 96 |
+
#################################################################################
|
| 97 |
+
|
class SiTBlock(nn.Module):
    """
    A SiT transformer block with adaptive layer norm zero (adaLN-Zero)
    conditioning: the conditioning vector yields shift/scale/gate terms
    for both the attention and MLP sub-blocks.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=approx_gelu,
            drop=0,
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        # Six conditioning terms: shift/scale/gate for attention, then for the MLP.
        mods = self.adaLN_modulation(c).chunk(6, dim=1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mods
        attn_out = self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_msa.unsqueeze(1) * attn_out
        mlp_out = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x + gate_mlp.unsqueeze(1) * mlp_out
| 120 |
+
|
| 121 |
+
|
class FinalLayer(nn.Module):
    """
    The final layer of SiT: adaLN-modulated LayerNorm followed by a linear
    projection from hidden_size to patch_size**2 * out_channels.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
| 140 |
+
|
| 141 |
+
|
class SiT(nn.Module):
    """
    Diffusion model with a Transformer backbone, conditioned on timestep and
    class label via adaLN-Zero modulation.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
    ):
        super().__init__()
        # BUG FIX: previously `self.learn_sigma = learn_sigma` was immediately
        # overwritten with True and `out_channels` was doubled unconditionally,
        # so passing learn_sigma=False was silently ignored. Honor the flag;
        # the default (True) preserves the original behavior for all existing
        # callers and checkpoints.
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads

        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding:
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        self.blocks = nn.ModuleList([
            SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """Initialize all weights; adaLN and output projections start at zero."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize (and freeze) pos_embed by sin-cos embedding:
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize label embedding table:
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in SiT blocks (blocks start as identity):
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        # FIX: width uses w * p (equivalent here since h == w is asserted, but
        # the previous `h * p, h * p` was misleading).
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def forward(self, x, t, y, return_act=False):
        """
        Forward pass of SiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        return_act: if True, also return the per-block token activations
        """
        act = []
        x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(t)                   # (N, D)
        y = self.y_embedder(y, self.training)    # (N, D)
        c = t + y                                # (N, D)
        for block in self.blocks:
            x = block(x, c)                      # (N, T, D)
            if return_act:
                act.append(x)
        x = self.final_layer(x, c)               # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)                   # (N, out_channels, H, W)
        if self.learn_sigma:
            # Keep only the predicted mean; drop the learned-sigma channels.
            x, _ = x.chunk(2, dim=1)
        if return_act:
            return x, act
        return x

    def forward_with_cfg(self, x, t, y, cfg_scale):
        """
        Forward pass of SiT, but also batches the unconditional forward pass
        for classifier-free guidance.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, y)
        # For exact reproducibility reasons, we apply classifier-free guidance on only
        # three channels by default. The standard approach to cfg applies it to all channels.
        # This can be done by uncommenting the following line and commenting-out the line following that.
        # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)
| 273 |
+
|
| 274 |
+
|
| 275 |
+
#################################################################################
|
| 276 |
+
# Sine/Cosine Positional Embedding Functions #
|
| 277 |
+
#################################################################################
|
| 278 |
+
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
| 279 |
+
|
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or
    [extra_tokens+grid_size*grid_size, embed_dim] (zero rows prepended for
    cls/extra tokens when requested)
    """
    coords_h = np.arange(grid_size, dtype=np.float32)
    coords_w = np.arange(grid_size, dtype=np.float32)
    grid = np.stack(np.meshgrid(coords_w, coords_h), axis=0)  # here w goes first
    grid = grid.reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        padding = np.zeros([extra_tokens, embed_dim])
        pos_embed = np.concatenate([padding, pos_embed], axis=0)
    return pos_embed
| 296 |
+
|
| 297 |
+
|
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Build a 2-D sin-cos embedding: one 1-D embedding per grid axis, each
    using half of embed_dim, concatenated along the feature dimension."""
    assert embed_dim % 2 == 0

    half = embed_dim // 2
    emb_h = get_1d_sincos_pos_embed_from_grid(half, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(half, grid[1])  # (H*W, D/2)

    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Encode scalar positions with a geometric ladder of sine/cosine frequencies.

    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # Frequencies 10000^(-k/half), identical to the Transformer positional encoding.
    freqs = 1.0 / 10000 ** (np.arange(half, dtype=np.float64) / half)  # (D/2,)
    angles = np.outer(pos.reshape(-1), freqs)  # (M, D/2) outer product
    # Sine channels first, cosine channels second.
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
#################################################################################
|
| 331 |
+
# SiT Configs #
|
| 332 |
+
#################################################################################
|
| 333 |
+
|
| 334 |
+
# Constructor helpers for the standard SiT sizes.  Names follow "SiT-<size>/<patch>":
# XL/L/B/S pick depth/width/heads; the numeric suffix picks the patch size.

def SiT_XL_2(**kwargs):
    return SiT(patch_size=2, depth=28, hidden_size=1152, num_heads=16, **kwargs)

def SiT_XL_4(**kwargs):
    return SiT(patch_size=4, depth=28, hidden_size=1152, num_heads=16, **kwargs)

def SiT_XL_8(**kwargs):
    return SiT(patch_size=8, depth=28, hidden_size=1152, num_heads=16, **kwargs)

def SiT_L_2(**kwargs):
    return SiT(patch_size=2, depth=24, hidden_size=1024, num_heads=16, **kwargs)

def SiT_L_4(**kwargs):
    return SiT(patch_size=4, depth=24, hidden_size=1024, num_heads=16, **kwargs)

def SiT_L_8(**kwargs):
    return SiT(patch_size=8, depth=24, hidden_size=1024, num_heads=16, **kwargs)

def SiT_B_2(**kwargs):
    return SiT(patch_size=2, depth=12, hidden_size=768, num_heads=12, **kwargs)

def SiT_B_4(**kwargs):
    return SiT(patch_size=4, depth=12, hidden_size=768, num_heads=12, **kwargs)

def SiT_B_8(**kwargs):
    return SiT(patch_size=8, depth=12, hidden_size=768, num_heads=12, **kwargs)

def SiT_S_2(**kwargs):
    return SiT(patch_size=2, depth=12, hidden_size=384, num_heads=6, **kwargs)

def SiT_S_4(**kwargs):
    return SiT(patch_size=4, depth=12, hidden_size=384, num_heads=6, **kwargs)

def SiT_S_8(**kwargs):
    return SiT(patch_size=8, depth=12, hidden_size=384, num_heads=6, **kwargs)


# Registry mapping config-name strings (as used on the CLI) to constructors.
SiT_models = {
    'SiT-XL/2': SiT_XL_2, 'SiT-XL/4': SiT_XL_4, 'SiT-XL/8': SiT_XL_8,
    'SiT-L/2':  SiT_L_2,  'SiT-L/4':  SiT_L_4,  'SiT-L/8':  SiT_L_8,
    'SiT-B/2':  SiT_B_2,  'SiT-B/4':  SiT_B_4,  'SiT-B/8':  SiT_B_8,
    'SiT-S/2':  SiT_S_2,  'SiT-S/4':  SiT_S_4,  'SiT-S/8':  SiT_S_8,
}
|
| 377 |
+
|
| 378 |
+
#################################################################################
|
| 379 |
+
# SiTF1, SiTF2, CombinedModel #
|
| 380 |
+
#################################################################################
|
| 381 |
+
|
| 382 |
+
class SiTF1(nn.Module):
    """
    SiTF1: first-stage SiT transformer.

    Standard SiT backbone (patch embedding + timestep/label conditioning +
    adaLN transformer blocks) whose forward pass returns BOTH the final patch
    tokens (consumed by a second-stage model such as SiTF2) and a decoded
    image tensor with the extra channels dropped.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
        final_layer=None,
    ):
        super().__init__()
        self.input_size = input_size
        self.patch_size = patch_size  # was assigned twice in the original; deduplicated
        self.hidden_size = hidden_size
        self.in_channels = in_channels
        # Head emits 2x the input channels; forward() keeps only the first half.
        self.out_channels = in_channels * 2
        self.num_heads = num_heads
        self.learn_sigma = learn_sigma
        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Frozen 2D sin/cos positional embedding (filled in initialize_weights).
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
        self.blocks = nn.ModuleList([
            SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])
        # NOTE(review): the `final_layer` parameter is accepted but not used here
        # (unlike SiTF2); kept for signature compatibility.
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)  (the original docstring said NHWC; the code returns NCHW)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        # Use w*p for the width (equivalent for square grids, correct in general).
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def initialize_weights(self):
        """Xavier-init linears, load the sin/cos pos-embed, zero adaLN modulations."""
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)
        # Overwrite the frozen positional embedding with the sin/cos table.
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        # Initialize the patch-embedding projection like a linear layer.
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        # Zero-init adaLN modulations so every block starts near the identity.
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

    def forward(self, x, t, y):
        """
        x: (N, C, H, W) latent input; t: (N,) timesteps; y: (N,) class labels.
        Returns (patch_tokens, image): final patch tokens (N, T, D) and the
        decoded image (N, in_channels, H, W) with the extra channels dropped.
        """
        x = self.x_embedder(x) + self.pos_embed
        t = self.t_embedder(t)
        y = self.y_embedder(y, self.training)
        c = t + y
        for block in self.blocks:
            x = block(x, c)
        x_now = self.final_layer(x, c)  # (N, T, patch_size ** 2 * out_channels)
        x_now = self.unpatchify(x_now)  # (N, out_channels, H, W)
        x_now, _ = x_now.chunk(2, dim=1)  # keep only the first in_channels channels
        return x, x_now  # patch token (N, T, D)

    def forward_with_cfg(self, x, t, y, cfg_scale):
        """
        Forward pass with classifier-free guidance for SiTF1.
        Applies guidance consistently to both patch tokens and image output (x_now).
        """
        # Take the first half (conditional inputs) and duplicate it so that
        # it can be paired with conditional and unconditional labels in `y`.
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        patch_tokens, x_now = self.forward(combined, t, y)

        # Apply CFG on the image output channels (first 3 channels by default)
        eps, rest = x_now[:, :3, ...], x_now[:, 3:, ...]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        x_now = torch.cat([eps, rest], dim=1)

        # Apply same guidance logic to patch tokens so downstream modules see
        # a consistent guided representation.
        cond_tok, uncond_tok = torch.split(patch_tokens, len(patch_tokens) // 2, dim=0)
        half_tok = uncond_tok + cfg_scale * (cond_tok - uncond_tok)
        patch_tokens = torch.cat([half_tok, half_tok], dim=0)

        return patch_tokens, x_now
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
class SiTF2(nn.Module):
    """
    SiTF2: second-stage head operating on pre-embedded (possibly concatenated)
    patch tokens.

    Runs a short stack of adaLN SiT blocks, decodes with a FinalLayer, then
    either returns the first half of the channels (learn_sigma=False) or
    interprets the output as (mean, log-variance) and draws a Gaussian sample
    (learn_sigma=True; the mean is included only when learn_mu=True).
    """
    def __init__(
        self,
        input_size=32,
        hidden_size=1152,
        out_channels=8,
        patch_size=2,
        num_heads=16,
        mlp_ratio=4.0,
        depth=4,
        learn_sigma=True,
        final_layer=None,
        num_classes=1000,
        class_dropout_prob=0.1,
        learn_mu=False,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.learn_mu = learn_mu
        self.out_channels = out_channels
        self.in_channels = 4
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.blocks = nn.ModuleList([
            SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])
        self.x_embedder = PatchEmbed(input_size, patch_size, self.in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        self.num_patches = num_patches  # Save original num_patches for unpatchify
        # pos_embed supports 2*num_patches for concatenated input (see CombinedModel).
        self.pos_embed = nn.Parameter(torch.zeros(1, 2 * num_patches, hidden_size), requires_grad=False)
        # Initialize pos_embed with sin-cos embedding, repeated for both halves.
        pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches ** 0.5))
        pos_embed_full = np.concatenate([pos_embed, pos_embed], axis=0)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed_full).float().unsqueeze(0))

        if final_layer is not None:
            self.final_layer = final_layer
        else:
            self.final_layer = FinalLayer(hidden_size, patch_size, out_channels)

    def unpatchify(self, x, patch_size, out_channels):
        """Fold (N, T, p*p*c) tokens back into (N, c, H, W) images.

        When the sequence carries 2*num_patches tokens (concatenated input),
        only the first half is decoded.
        """
        c = out_channels
        p = patch_size
        # Use the original num_patches to derive h and w; the sequence may hold
        # 2*num_patches tokens when fed the concatenated input.
        h = w = int(self.num_patches ** 0.5)
        if x.shape[1] == 2 * self.num_patches:
            # Take only the first half of the sequence.
            x = x[:, :self.num_patches, :]
        assert h * w == x.shape[1], f"Expected {h * w} patches, got {x.shape[1]}"
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

    def _finalize(self, x):
        """Shared post-processing for forward/forward_noise (was duplicated).

        learn_sigma=True: split channels into (mean, log-variance) and draw a
        Gaussian sample (mean added only when learn_mu=True).
        learn_sigma=False: return the first half of the channels unchanged.
        """
        if self.learn_sigma:
            mean_pred, log_var_pred = x.chunk(2, dim=1)
            std_dev_pred = torch.sqrt(torch.exp(log_var_pred))
            noise = torch.randn_like(mean_pred)
            if self.learn_mu:
                return mean_pred + std_dev_pred * noise
            return std_dev_pred * noise
        x, _ = x.chunk(2, dim=1)
        return x

    def forward(self, x, c, t, return_act=False):
        """
        x: (N, T or 2T, D) embedded tokens; c: (N, D) conditioning vector.
        `t` is accepted for interface compatibility but unused in the body.
        Returns the decoded (and possibly resampled) image; with
        return_act=True also returns the list of per-block activations.
        """
        act = []
        for block in self.blocks:
            x = block(x, c)
            if return_act:
                act.append(x)
        x = self.final_layer(x, c)
        x = self.unpatchify(x, self.patch_size, self.out_channels)
        x = self._finalize(x)
        if return_act:
            return x, act
        return x

    def forward_noise(self, x, c):
        """Same as forward() but without the timestep arg or activation capture."""
        for block in self.blocks:
            x = block(x, c)
        x = self.final_layer(x, c)
        x = self.unpatchify(x, self.patch_size, self.out_channels)
        return self._finalize(x)
|
| 607 |
+
|
| 608 |
+
# Two design axes here: use the ideal vs. the real signal, and combine the
# features by concatenation vs. by addition.
|
| 609 |
+
class CombinedModel(nn.Module):
    """
    CombinedModel: chains SiTF1 and SiTF2.

    SiTF1 produces final patch tokens and a decoded image; that image is mixed
    with the original input, re-patchified, concatenated with the tokens, and
    the doubled-length sequence is fed to SiTF2.
    """
    def __init__(self, sitf1: SiTF1, sitf2: SiTF2):
        super().__init__()
        self.sitf1 = sitf1
        self.sitf2 = sitf2
        # Mirror SiTF1's geometry so the re-patchified image aligns with its tokens.
        input_size=self.sitf1.input_size
        patch_size=self.sitf1.patch_size
        hidden_size=self.sitf1.hidden_size
        # Separate embedder for the mixed image (4 latent channels).
        self.x_embedder = PatchEmbed(input_size, patch_size, 4, hidden_size, bias=True)
        num_patches = self.x_embedder.num_patches
        # pos_embed needs to support 2*num_patches for concatenated input
        self.pos_embed = nn.Parameter(torch.zeros(1, 2 * num_patches, hidden_size), requires_grad=False)
        # Initialize pos_embed with sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches ** 0.5))
        # Repeat the pos_embed for both halves (or could use different embeddings)
        pos_embed_full = np.concatenate([pos_embed, pos_embed], axis=0)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed_full).float().unsqueeze(0))

    def forward(self, x, t, y, return_act=False):
        # Stage 1: SiTF1 yields patch tokens (N, T, D) and a decoded image.
        patch_tokens,x_now = self.sitf1(x, t, y)
        # t shape is (N,); broadcast to (N, 1, 1, 1) against the image (N, C, H, W).
        t_broadcast = t.view(-1, 1, 1, 1)  # (N, 1, 1, 1)
        # NOTE(review): the original comment claimed "(1-t)*x_now + t*x", but the
        # code computes (1-t)*x_now + x (x is NOT scaled by t) -- confirm whether
        # the missing `t_broadcast *` on x is intentional before relying on this.
        x_interpolated = (1 - t_broadcast) * x_now + x
        # Convert the mixed image back to patch-token format (pos_embed added below).
        x_now_patches = self.x_embedder(x_interpolated)
        # Concatenate patch_tokens and x_now_patches along the sequence dimension.
        concatenated_input = torch.cat([patch_tokens, x_now_patches], dim=1)  # (N, 2*T, D)
        # Add the (repeated) position embedding for the doubled sequence.
        concatenated_input = concatenated_input + self.pos_embed
        # Conditioning vector reuses SiTF1's timestep/label embedders.
        t_emb = self.sitf1.t_embedder(t)
        y_emb = self.sitf1.y_embedder(y, self.training)
        c = t_emb + y_emb
        return self.sitf2(concatenated_input, c, t, return_act=return_act)
|
Rectified_Noise/GVP-Disp/sample_ddp.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This source code is licensed under the license found in the
|
| 2 |
+
# LICENSE file in the root directory of this source tree.
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Samples a large number of images from a pre-trained SiT model using DDP.
|
| 6 |
+
Subsequently saves a .npz file that can be used to compute FID and other
|
| 7 |
+
evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations
|
| 8 |
+
|
| 9 |
+
For a simple single-GPU/CPU sampling script, see sample.py.
|
| 10 |
+
"""
|
| 11 |
+
import torch
|
| 12 |
+
import torch.distributed as dist
|
| 13 |
+
from models import SiT_models
|
| 14 |
+
from download import find_model
|
| 15 |
+
from transport import create_transport, Sampler
|
| 16 |
+
from diffusers.models import AutoencoderKL
|
| 17 |
+
from train_utils import parse_ode_args, parse_sde_args, parse_transport_args
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
import os
|
| 20 |
+
from PIL import Image
|
| 21 |
+
import numpy as np
|
| 22 |
+
import math
|
| 23 |
+
import argparse
|
| 24 |
+
import sys
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Builds a single .npz file from a folder of .png samples.
    """
    images = [
        np.asarray(Image.open(f"{sample_dir}/{idx:06d}.png")).astype(np.uint8)
        for idx in tqdm(range(num), desc="Building .npz file from samples")
    ]
    stacked = np.stack(images)
    # Every sample must be an RGB image of a common resolution.
    assert stacked.shape == (num, stacked.shape[1], stacked.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=stacked)
    print(f"Saved .npz file to {npz_path} [shape={stacked.shape}].")
    return npz_path
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def main(mode, args):
    """
    Run DDP sampling for a SiT model: each rank writes individual .png samples,
    then rank 0 bundles them into a single .npz for FID evaluation.

    mode: "ODE" or "SDE" -- selects which sampler is built from the transport.
    args: parsed CLI namespace (see the __main__ block of this script).
    """
    torch.backends.cuda.matmul.allow_tf32 = args.tf32  # True: fast but may lead to some small numerical differences
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Setup DDP: one process per GPU; each rank gets a distinct seed so samples differ.
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    if args.ckpt is None:
        # Auto-download path only supports the released SiT-XL/2 @ 256x256.
        assert args.model == "SiT-XL/2", "Only SiT-XL/2 models are available for auto-download."
        assert args.image_size in [256, 512]
        assert args.num_classes == 1000
        assert args.image_size == 256, "512x512 models are not yet available for auto-download." # remove this line when 512x512 models are available
        learn_sigma = args.image_size == 256
    else:
        learn_sigma = False

    # Load model:
    latent_size = args.image_size // 8  # SD-VAE downsamples the image by 8x
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        learn_sigma=learn_sigma,
    ).to(device)
    # Auto-download a pre-trained model or load a custom SiT checkpoint from train.py:
    ckpt_path = args.ckpt or f"SiT-XL-2-{args.image_size}x{args.image_size}.pt"
    state_dict = find_model(ckpt_path)
    model.load_state_dict(state_dict)
    model.eval()  # important!

    # Build the probability-flow transport and wrap it in a sampler factory.
    transport = create_transport(
        args.path_type,
        args.prediction,
        args.loss_weight,
        args.train_eps,
        args.sample_eps
    )
    sampler = Sampler(transport)
    if mode == "ODE":
        if args.likelihood:
            # Likelihood evaluation requires the exact ODE, so no guidance.
            assert args.cfg_scale == 1, "Likelihood is incompatible with guidance"
            sample_fn = sampler.sample_ode_likelihood(
                sampling_method=args.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=args.atol,
                rtol=args.rtol,
            )
        else:
            sample_fn = sampler.sample_ode(
                sampling_method=args.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=args.atol,
                rtol=args.rtol,
                reverse=args.reverse
            )
    elif mode == "SDE":
        sample_fn = sampler.sample_sde(
            sampling_method=args.sampling_method,
            diffusion_form=args.diffusion_form,
            diffusion_norm=args.diffusion_norm,
            last_step=args.last_step,
            last_step_size=args.last_step_size,
            num_steps=args.num_sampling_steps,
        )
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
    assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale be >= 1.0"
    using_cfg = args.cfg_scale > 1.0

    # Create folder to save samples; the name encodes the sampling configuration.
    model_string_name = args.model.replace("/", "-")
    ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
    if mode == "ODE":
        folder_name = f"{model_string_name}-{ckpt_string_name}-" \
                      f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
                      f"{mode}-{args.num_sampling_steps}-{args.sampling_method}"
    elif mode == "SDE":
        folder_name = f"{model_string_name}-{ckpt_string_name}-" \
                      f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
                      f"{mode}-{args.num_sampling_steps}-{args.sampling_method}-"\
                      f"{args.diffusion_form}-{args.last_step}-{args.last_step_size}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()
    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
    num_samples = len([name for name in os.listdir(sample_folder_dir) if (os.path.isfile(os.path.join(sample_folder_dir, name)) and ".png" in name)])
    total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
    assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // dist.get_world_size())
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    # NOTE(review): done_iterations is computed but never used below -- there is
    # no resume-from-existing-samples logic despite num_samples being counted.
    done_iterations = int( int(num_samples // dist.get_world_size()) // n)
    pbar = range(iterations)
    pbar = tqdm(pbar) if rank == 0 else pbar
    total = 0

    for i in pbar:
        # Sample inputs: Gaussian latents and random class labels.
        z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device)
        y = torch.randint(0, args.num_classes, (n,), device=device)

        # Setup classifier-free guidance: duplicate latents and append null labels.
        if using_cfg:
            z = torch.cat([z, z], 0)
            y_null = torch.tensor([1000] * n, device=device)  # 1000 = null/unconditional class id
            y = torch.cat([y, y_null], 0)
            model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
            model_fn = model.forward_with_cfg
        else:
            model_kwargs = dict(y=y)
            model_fn = model.forward

        # sample_fn returns the whole trajectory; [-1] is the final state.
        samples = sample_fn(z, model_fn, **model_kwargs)[-1]
        if using_cfg:
            samples, _ = samples.chunk(2, dim=0)  # Remove null class samples

        # 0.18215 is the SD-VAE latent scaling factor.
        samples = vae.decode(samples / 0.18215).sample
        samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

        # Save samples to disk as individual .png files; the index interleaves
        # ranks so file names are globally unique.
        # NOTE(review): this inner `i` shadows the outer iteration variable
        # (harmless here since the outer `i` is unused, but worth renaming).
        for i, sample in enumerate(samples):
            index = i * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
        total += global_batch_size
        dist.barrier()

    # Make sure all processes have finished saving their samples before attempting to convert to .npz
    dist.barrier()
    if rank == 0:
        create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
        print("Done.")
    dist.barrier()
    dist.destroy_process_group()
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
if __name__ == "__main__":
    # The first positional CLI token selects the sampler mode; the remaining
    # flags are mode-specific, so the mode must be read before building the parser.
    parser = argparse.ArgumentParser()

    if len(sys.argv) < 2:
        print("Usage: program.py <mode> [options]")
        sys.exit(1)

    mode = sys.argv[1]

    assert mode[:2] != "--", "Usage: program.py <mode> [options]"
    assert mode in ["ODE", "SDE"], "Invalid mode. Please choose 'ODE' or 'SDE'"

    parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")
    parser.add_argument("--sample-dir", type=str, default="samples")
    parser.add_argument("--per-proc-batch-size", type=int, default=4)
    parser.add_argument("--num-fid-samples", type=int, default=50_000)
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--cfg-scale", type=float, default=1.0)
    parser.add_argument("--num-sampling-steps", type=int, default=250)
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
                        help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.")
    parser.add_argument("--ckpt", type=str, default=None,
                        help="Optional path to a SiT checkpoint (default: auto-download a pre-trained SiT-XL/2 model).")

    # Mode-agnostic transport flags plus mode-specific solver flags.
    parse_transport_args(parser)
    if mode == "ODE":
        parse_ode_args(parser)
        # Further processing for ODE
    elif mode == "SDE":
        parse_sde_args(parser)
        # Further processing for SDE

    # parse_known_args tolerates the positional <mode> token already consumed above.
    args = parser.parse_known_args()[0]
    main(mode, args)
|
Rectified_Noise/GVP-Disp/sample_rectified_noise.py
ADDED
|
@@ -0,0 +1,380 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.distributed as dist
|
| 3 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 4 |
+
from models import SiT_models
|
| 5 |
+
from download import find_model
|
| 6 |
+
from transport import create_transport, Sampler
|
| 7 |
+
from diffusers.models import AutoencoderKL
|
| 8 |
+
from train_utils import parse_ode_args, parse_sde_args, parse_transport_args
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import os
|
| 11 |
+
from PIL import Image
|
| 12 |
+
import numpy as np
|
| 13 |
+
import math
|
| 14 |
+
import argparse
|
| 15 |
+
import sys
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Builds a single .npz file from a folder of .png samples.
    """
    arrays = []
    for index in tqdm(range(num), desc="Building .npz file from samples"):
        image = Image.open(f"{sample_dir}/{index:06d}.png")
        arrays.append(np.asarray(image).astype(np.uint8))
    batch = np.stack(arrays)
    # Every sample must be an RGB image of a common resolution.
    assert batch.shape == (num, batch.shape[1], batch.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=batch)
    print(f"Saved .npz file to {npz_path} [shape={batch.shape}].")
    return npz_path
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def fix_state_dict_for_ddp(state_dict):
    """
    Normalize checkpoint keys so they match a DistributedDataParallel model.

    Two normalizations are applied:
      1. If *state_dict* is a full training checkpoint (a dict containing
         "model", "ema", or "opt" entries), the inner weights are extracted,
         preferring "ema" over "model".
      2. Every remaining key gets a "module." prefix — DDP wraps the model,
         so its parameter names are all prefixed — unless it already has one.

    Args:
        state_dict: Either a raw model state dict or a full checkpoint dict.

    Returns:
        dict: A new state dict whose keys all start with "module.".
    """
    if isinstance(state_dict, dict) and ("model" in state_dict or "ema" in state_dict or "opt" in state_dict):
        # Full checkpoint: prefer the EMA weights, then the raw model weights.
        if "ema" in state_dict:
            state_dict = state_dict["ema"]
        elif "model" in state_dict:
            state_dict = state_dict["model"]
        # NOTE(review): if only "opt" (or other) entries exist, the checkpoint
        # dict itself falls through and gets prefixed below — preserved from
        # the original behavior, but probably not a loadable state dict.

    # Prefix keys with "module." to match the DDP-wrapped parameter names.
    return {
        (key if key.startswith("module.") else "module." + key): value
        for key, value in state_dict.items()
    }
|
| 61 |
+
|
| 62 |
+
def main(mode, args):
    """
    Run distributed sampling with the SiT base model plus the SiTF2 correction head.

    Spawns one process per GPU (launched via torchrun / NCCL), builds the
    SiTF1/SiTF2/CombinedModel stack, samples latents with the requested
    ODE/SDE solver, decodes them with the SD VAE, writes numbered .png files
    into a run-specific folder, and (on rank 0) packs them into a single
    .npz for FID evaluation.

    Args:
        mode: "ODE" or "SDE" — selects the sampler family.
        args: Parsed CLI namespace (see the __main__ block for the flags).
    """
    torch.backends.cuda.matmul.allow_tf32 = args.tf32  # True: fast but may lead to some small numerical differences
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)
    learn_mu = args.learn_mu
    sitf2_depth = args.depth  # Save SiTF2 depth before it gets overwritten

    # Setup DDP:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    # Per-rank seed so every GPU draws different noise / labels.
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    if args.ckpt is None:
        # Auto-download path: only the released SiT-XL/2 256x256 weights exist.
        assert args.model == "SiT-XL/2", "Only SiT-XL/2 models are available for auto-download."
        assert args.image_size in [256, 512]
        assert args.num_classes == 1000
        assert args.image_size == 256, "512x512 models are not yet available for auto-download."  # remove this line when 512x512 models are available
        learn_sigma = args.image_size == 256  # NOTE(review): never read afterwards
    else:
        learn_sigma = False  # NOTE(review): never read afterwards

    # Load SiTF1 and SiTF2 models and create CombinedModel
    from models import SiTF1, SiTF2, CombinedModel
    latent_size = args.image_size // 8  # SD VAE downsamples by 8x

    # Derive the transformer configuration from the model-name family
    # (mirrors the SiT_models presets). 'XL' is checked before 'L' because
    # "SiT-XL/2" contains both substrings.
    model_name = args.model
    if 'XL' in model_name:
        hidden_size, depth, num_heads = 1152, 28, 16
    elif 'L' in model_name:
        hidden_size, depth, num_heads = 1024, 24, 16
    elif 'B' in model_name:
        hidden_size, depth, num_heads = 768, 12, 12
    elif 'S' in model_name:
        hidden_size, depth, num_heads = 384, 12, 6
    else:
        # Default fallback
        hidden_size, depth, num_heads = 768, 12, 12

    # Extract patch size from model name like 'SiT-XL/2' -> patch_size = 2
    patch_size = int(model_name.split('/')[-1])

    # Load SiTF1
    sitf1 = SiTF1(
        input_size=latent_size,
        patch_size=patch_size,
        in_channels=4,
        hidden_size=hidden_size,
        depth=depth,
        num_heads=num_heads,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=args.num_classes,
        learn_sigma=False
    ).to(device)
    sitf1_state_raw = find_model(args.ckpt)
    # find_model returns the EMA weights when available, or the full checkpoint.
    # Extract the actual state_dict; it is reused below for base_model too.
    if isinstance(sitf1_state_raw, dict) and "model" in sitf1_state_raw:
        sitf1_state = sitf1_state_raw["model"]
    else:
        # sitf1_state_raw is already a state_dict (either ema or direct model state)
        sitf1_state = sitf1_state_raw
    sitf1.load_state_dict(sitf1_state)
    sitf1.eval()

    # Load SiTF2 with the same architecture parameters as SiTF1 for
    # compatibility, except `depth`, which comes from the --depth CLI flag.
    sitf2 = SiTF2(
        input_size=latent_size,
        hidden_size=hidden_size,
        out_channels=8,
        patch_size=patch_size,
        num_heads=num_heads,
        mlp_ratio=4.0,
        depth=sitf2_depth,
        learn_sigma=True,
        num_classes=args.num_classes,
        learn_mu=learn_mu
    ).to(device)
    sitf2 = DDP(sitf2, device_ids=[device])
    sitf2_state = find_model(args.sitf2_ckpt)
    # Fix state dict keys ("module." prefix) to match the DDP-wrapped model.
    sitf2_state_fixed = fix_state_dict_for_ddp(sitf2_state)
    try:
        sitf2.load_state_dict(sitf2_state_fixed)
    except Exception as e:
        print(f"Error loading state dict: {e}")
        # Fall back to a non-strict load so partially-matching checkpoints work.
        sitf2.load_state_dict(sitf2_state_fixed, strict=False)
    sitf2.eval()

    combined_model = CombinedModel(sitf1, sitf2).to(device)
    sitf2.eval()
    combined_model.eval()

    # Use the SiT_models factory to create the base model (learn_sigma=False
    # to match sitf1) and load the exact same state_dict.
    base_model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        learn_sigma=False,
    ).to(device)
    base_model.load_state_dict(sitf1_state)
    base_model.eval()

    # Determine if CFG will be used (needed for combined_sampling_model below).
    assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale should be >= 1.0"
    using_cfg = args.cfg_scale > 1.0

    # There are repeated calculations in the middle, which will cause FLOPs to
    # double. A simplified version will be released later.
    def combined_sampling_model(x, t, y=None, **kwargs):
        """Sampling callable: base SiT velocity, optionally corrected by SiTF2."""
        with torch.no_grad():
            # Handle CFG the same way as SiT_clean/sample_ddp.py.
            if using_cfg and 'cfg_scale' in kwargs:
                sit_out = base_model.forward_with_cfg(x, t, y, kwargs['cfg_scale'])
            else:
                sit_out = base_model.forward(x, t, y)
            if args.use_sitf2:
                if args.use_sitf2_before_t05:
                    # Add the SiTF2 correction only for samples with t < threshold.
                    mask = (t < args.sitf2_threshold).float()
                    combined_out = combined_model.forward(x, t, y)
                    # Broadcast the per-sample mask over (batch, C, H, W).
                    while len(mask.shape) < len(combined_out.shape):
                        mask = mask.unsqueeze(-1)
                    mask = mask.expand_as(combined_out)
                    combined_out = combined_out * mask
                    return sit_out + combined_out
                else:
                    # NOTE(review): with --use-sitf2 but without
                    # --use-sitf2-before-t05 the correction is never applied;
                    # confirm this is intended.
                    return sit_out
            else:
                # Default behavior: only use base model output.
                return sit_out

    transport = create_transport(
        args.path_type,
        args.prediction,
        args.loss_weight,
        args.train_eps,
        args.sample_eps
    )
    sampler = Sampler(transport)
    if mode == "ODE":
        if args.likelihood:
            assert args.cfg_scale == 1, "Likelihood is incompatible with guidance"
            sample_fn = sampler.sample_ode_likelihood(
                sampling_method=args.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=args.atol,
                rtol=args.rtol,
            )
        else:
            sample_fn = sampler.sample_ode(
                sampling_method=args.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=args.atol,
                rtol=args.rtol,
                reverse=args.reverse
            )
    elif mode == "SDE":
        sample_fn = sampler.sample_sde(
            sampling_method=args.sampling_method,
            diffusion_form=args.diffusion_form,
            diffusion_norm=args.diffusion_norm,
            last_step=args.last_step,
            last_step_size=args.last_step_size,
            num_steps=args.num_sampling_steps,
        )
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)

    # Create folder to save samples:
    model_string_name = args.model.replace("/", "-")  # NOTE(review): unused
    ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
    # BUGFIX: this previously tested `args.ckpt`, so the SiTF2 checkpoint name
    # was mislabeled "pretrained" whenever --ckpt was omitted.
    sitf2_ckpt_string_name = os.path.basename(args.sitf2_ckpt).replace(".pt", "") if args.sitf2_ckpt else "pretrained"
    if mode == "ODE":
        folder_name = f"{sitf2_ckpt_string_name}-{ckpt_string_name}-" \
                      f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
                      f"{mode}-{args.num_sampling_steps}-{args.sampling_method}"
    elif mode == "SDE":
        # Add threshold info to the folder name if use_sitf2_before_t05 is enabled.
        threshold_suffix = f"-threshold-{args.sitf2_threshold}" if args.use_sitf2_before_t05 else ""
        if learn_mu:
            folder_name = f"depth-mu-{sitf2_depth}{threshold_suffix}-{sitf2_ckpt_string_name}-{ckpt_string_name}-" \
                          f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
                          f"{mode}-{args.num_sampling_steps}-{args.sampling_method}-"\
                          f"{args.diffusion_form}-{args.last_step}-{args.last_step_size}"
        else:
            folder_name = f"depth-sigma-{sitf2_depth}{threshold_suffix}-{sitf2_ckpt_string_name}-{ckpt_string_name}-" \
                          f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
                          f"{mode}-{args.num_sampling_steps}-{args.sampling_method}-"\
                          f"{args.diffusion_form}-{args.last_step}-{args.last_step_size}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()

    # Figure out how many samples each GPU generates and how many iterations to run.
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()
    # To make things evenly-divisible, sample a bit more than needed and discard the extras.
    num_samples = len([name for name in os.listdir(sample_folder_dir) if (os.path.isfile(os.path.join(sample_folder_dir, name)) and ".png" in name)])
    total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
    assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // dist.get_world_size())
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    # NOTE(review): done_iterations is computed but never used — resuming a
    # partially-filled sample folder is not actually implemented.
    done_iterations = int( int(num_samples // dist.get_world_size()) // n)
    pbar = range(iterations)
    pbar = tqdm(pbar) if rank == 0 else pbar
    total = 0

    for i in pbar:
        # Sample inputs:
        z = torch.randn(n, base_model.in_channels, latent_size, latent_size, device=device)
        y = torch.randint(0, args.num_classes, (n,), device=device)
        # Setup classifier-free guidance: duplicate the batch, second half
        # gets the null class (index 1000).
        if using_cfg:
            z = torch.cat([z, z], 0)
            y_null = torch.tensor([1000] * n, device=device)
            y = torch.cat([y, y_null], 0)
            model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
        else:
            model_kwargs = dict(y=y)
        samples = sample_fn(z, combined_sampling_model, **model_kwargs)[-1]
        if using_cfg:
            samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
        # 0.18215 is the SD latent scaling factor; map decoded [-1, 1] to uint8.
        samples = vae.decode(samples / 0.18215).sample
        samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
        # Save samples to disk as individual .png files. Each rank writes a
        # disjoint, interleaved index range. (Renamed from `i` to `j` — the
        # original shadowed the outer loop variable.)
        for j, sample in enumerate(samples):
            index = j * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
        total += global_batch_size
        dist.barrier()

    # Make sure all processes have finished saving their samples before
    # attempting to convert to .npz.
    dist.barrier()
    if rank == 0:
        create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
        print("Done.")
    dist.barrier()
    dist.destroy_process_group()
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
if __name__ == "__main__":
    # The first positional token selects the sampler family; it must come
    # before any --flags because it decides which extra flags are registered.
    arg_parser = argparse.ArgumentParser()

    if len(sys.argv) < 2:
        print("Usage: program.py <mode> [options]")
        sys.exit(1)

    sampling_mode = sys.argv[1]

    assert sampling_mode[:2] != "--", "Usage: program.py <mode> [options]"
    assert sampling_mode in ["ODE", "SDE"], "Invalid mode. Please choose 'ODE' or 'SDE'"

    # Core sampling / model flags.
    arg_parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
    arg_parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")
    arg_parser.add_argument("--sample-dir", type=str, default="samples")
    arg_parser.add_argument("--per-proc-batch-size", type=int, default=64)
    arg_parser.add_argument("--num-fid-samples", type=int, default=50_000)
    arg_parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    arg_parser.add_argument("--num-classes", type=int, default=1000)
    arg_parser.add_argument("--cfg-scale", type=float, default=1.0)
    arg_parser.add_argument("--num-sampling-steps", type=int, default=100)
    arg_parser.add_argument("--global-seed", type=int, default=0)
    arg_parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
                            help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.")
    # Checkpoints and SiTF2-specific flags.
    arg_parser.add_argument("--ckpt", type=str, default=None,
                            help="Optional path to a SiT checkpoint.")
    arg_parser.add_argument("--sitf2-ckpt", type=str, required=True, help="Path to SiTF2 checkpoint")
    arg_parser.add_argument("--learn-mu", action=argparse.BooleanOptionalAction, default=True,
                            help="Whether to learn mu parameter")
    arg_parser.add_argument("--depth", type=int, default=1,
                            help="Depth parameter for SiTF2 model")
    arg_parser.add_argument("--use-sitf2", action=argparse.BooleanOptionalAction, default=True,
                            help="Only use SiTF2 output when t < threshold, otherwise use only SiT")
    arg_parser.add_argument("--use-sitf2-before-t05", action=argparse.BooleanOptionalAction, default=False,
                            help="Only use SiTF2 output when t < threshold, otherwise use only SiT")
    arg_parser.add_argument("--sitf2-threshold", type=float, default=0.5,
                            help="Time threshold for using SiTF2 output (default: 0.5). Only effective when --use-sitf2-before-t05 is True")

    # Shared transport flags plus the solver-family-specific group.
    parse_transport_args(arg_parser)
    if sampling_mode == "ODE":
        parse_ode_args(arg_parser)
        # Further processing for ODE
    elif sampling_mode == "SDE":
        parse_sde_args(arg_parser)
        # Further processing for SDE

    # parse_known_args tolerates extra launcher-injected flags (e.g. torchrun's).
    cli_args = arg_parser.parse_known_args()[0]
    main(sampling_mode, cli_args)
|
Rectified_Noise/GVP-Disp/train_utils.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def none_or_str(value):
    """Argparse type helper: map the literal string 'None' to Python None,
    and pass every other value through unchanged."""
    return None if value == 'None' else value
|
| 5 |
+
|
| 6 |
+
def parse_transport_args(parser):
    """Register the transport-related CLI flags on *parser* in their own group."""
    transport_group = parser.add_argument_group("Transport arguments")
    transport_group.add_argument(
        "--path-type",
        type=str,
        default="GVP",
        choices=["Linear", "GVP", "VP"],
        help=(
            "Type of path for loss calculation. This parameter directly affects the loss form used during training. "
            "Choices: Linear (linear interpolation path), GVP (Geodesic Velocity Path), VP (Velocity Path). "
            "The path_type determines how the transport loss is computed in training_losses()."
        ),
    )
    transport_group.add_argument(
        "--prediction",
        type=str,
        default="velocity",
        choices=["velocity", "score", "noise"],
    )
    # 'None' on the command line is converted to Python None by none_or_str.
    transport_group.add_argument(
        "--loss-weight",
        type=none_or_str,
        default=None,
        choices=[None, "velocity", "likelihood"],
    )
    transport_group.add_argument("--sample-eps", type=float, default=0.0)
    transport_group.add_argument("--train-eps", type=float, default=0.0)
|
| 16 |
+
|
| 17 |
+
def parse_ode_args(parser):
    """Register the ODE-solver CLI flags on *parser* in their own group."""
    ode_group = parser.add_argument_group("ODE arguments")
    ode_group.add_argument(
        "--sampling-method",
        type=str,
        default="dopri5",
        help="blackbox ODE solver methods; for full list check https://github.com/rtqichen/torchdiffeq",
    )
    ode_group.add_argument("--atol", type=float, default=1e-6, help="Absolute tolerance")
    ode_group.add_argument("--rtol", type=float, default=1e-3, help="Relative tolerance")
    # Boolean toggles, both off by default.
    for flag in ("--reverse", "--likelihood"):
        ode_group.add_argument(flag, action="store_true")
|
| 24 |
+
|
| 25 |
+
def parse_sde_args(parser):
    """Register the SDE-solver CLI flags on *parser* in their own group."""
    sde_group = parser.add_argument_group("SDE arguments")
    sde_group.add_argument(
        "--sampling-method",
        type=str,
        default="Euler",
        choices=["Euler", "Heun"],
    )
    sde_group.add_argument(
        "--diffusion-form",
        type=str,
        default="sigma",
        choices=["constant", "SBDM", "sigma", "linear", "decreasing", "increasing-decreasing"],
        help="form of diffusion coefficient in the SDE",
    )
    sde_group.add_argument("--diffusion-norm", type=float, default=1.0)
    # 'None' on the command line is converted to Python None by none_or_str.
    sde_group.add_argument(
        "--last-step",
        type=none_or_str,
        default="Mean",
        choices=[None, "Mean", "Tweedie", "Euler"],
        help="form of last step taken in the SDE",
    )
    sde_group.add_argument(
        "--last-step-size",
        type=float,
        default=0.04,
        help="size of the last step taken",
    )
|
Rectified_Noise/GVP-Disp/w_training1_VP.log
ADDED
|
@@ -0,0 +1,628 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
nohup: ignoring input
|
| 2 |
+
[NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
|
| 3 |
+
Starting rank=0, seed=0, world_size=1.
|
| 4 |
+
[[34m2026-02-01 14:09:25[0m] Experiment directory created at results_256_vp/depth-mu-2-000-SiT-XL-2-VP-velocity-None
|
| 5 |
+
[[34m2026-02-01 14:09:57[0m] Combined_model Parameters: 729,629,632
|
| 6 |
+
[[34m2026-02-01 14:09:57[0m] Total trainable parameters: 53,910,176
|
| 7 |
+
[[34m2026-02-01 14:09:59[0m] Dataset contains 1,281,167 images (/gemini/platform/public/zhaozy/hzh/datasets/Imagenet/train/)
|
| 8 |
+
[[34m2026-02-01 14:09:59[0m] Training for 100000 epochs...
|
| 9 |
+
[[34m2026-02-01 14:09:59[0m] Beginning epoch 0...
|
| 10 |
+
[[34m2026-02-01 14:11:18[0m] (step=0000100) Train Loss: 2.7011, Train Steps/Sec: 1.27
|
| 11 |
+
[[34m2026-02-01 14:12:34[0m] (step=0000200) Train Loss: 1.9056, Train Steps/Sec: 1.32
|
| 12 |
+
[[34m2026-02-01 14:13:49[0m] (step=0000300) Train Loss: 1.7930, Train Steps/Sec: 1.32
|
| 13 |
+
[[34m2026-02-01 14:15:05[0m] (step=0000400) Train Loss: 2.0316, Train Steps/Sec: 1.32
|
| 14 |
+
[[34m2026-02-01 14:16:21[0m] (step=0000500) Train Loss: 1.8412, Train Steps/Sec: 1.32
|
| 15 |
+
[[34m2026-02-01 14:17:37[0m] (step=0000600) Train Loss: 1.8505, Train Steps/Sec: 1.32
|
| 16 |
+
[[34m2026-02-01 14:18:53[0m] (step=0000700) Train Loss: 1.8542, Train Steps/Sec: 1.32
|
| 17 |
+
[[34m2026-02-01 14:20:09[0m] (step=0000800) Train Loss: 1.8904, Train Steps/Sec: 1.32
|
| 18 |
+
[[34m2026-02-01 14:21:25[0m] (step=0000900) Train Loss: 1.9280, Train Steps/Sec: 1.32
|
| 19 |
+
[[34m2026-02-01 14:22:41[0m] (step=0001000) Train Loss: 1.8453, Train Steps/Sec: 1.32
|
| 20 |
+
[[34m2026-02-01 14:23:57[0m] (step=0001100) Train Loss: 1.8745, Train Steps/Sec: 1.32
|
| 21 |
+
[[34m2026-02-01 14:25:13[0m] (step=0001200) Train Loss: 1.8410, Train Steps/Sec: 1.32
|
| 22 |
+
[[34m2026-02-01 14:26:29[0m] (step=0001300) Train Loss: 1.8445, Train Steps/Sec: 1.32
|
| 23 |
+
[[34m2026-02-01 14:27:44[0m] (step=0001400) Train Loss: 1.8173, Train Steps/Sec: 1.32
|
| 24 |
+
[[34m2026-02-01 14:29:00[0m] (step=0001500) Train Loss: 3.5917, Train Steps/Sec: 1.32
|
| 25 |
+
[[34m2026-02-01 14:30:16[0m] (step=0001600) Train Loss: 1.8747, Train Steps/Sec: 1.32
|
| 26 |
+
[[34m2026-02-01 14:31:32[0m] (step=0001700) Train Loss: 1.8092, Train Steps/Sec: 1.32
|
| 27 |
+
[[34m2026-02-01 14:32:48[0m] (step=0001800) Train Loss: 1.8720, Train Steps/Sec: 1.32
|
| 28 |
+
[[34m2026-02-01 14:34:04[0m] (step=0001900) Train Loss: 1.8186, Train Steps/Sec: 1.32
|
| 29 |
+
[[34m2026-02-01 14:35:20[0m] (step=0002000) Train Loss: 1.9034, Train Steps/Sec: 1.32
|
| 30 |
+
[[34m2026-02-01 14:36:36[0m] (step=0002100) Train Loss: 1.8993, Train Steps/Sec: 1.32
|
| 31 |
+
[[34m2026-02-01 14:37:52[0m] (step=0002200) Train Loss: 1.8499, Train Steps/Sec: 1.32
|
| 32 |
+
[[34m2026-02-01 14:39:08[0m] (step=0002300) Train Loss: 2.1165, Train Steps/Sec: 1.32
|
| 33 |
+
[[34m2026-02-01 14:40:24[0m] (step=0002400) Train Loss: 1.8346, Train Steps/Sec: 1.32
|
| 34 |
+
[[34m2026-02-01 14:41:40[0m] (step=0002500) Train Loss: 1.7744, Train Steps/Sec: 1.32
|
| 35 |
+
[[34m2026-02-01 14:42:56[0m] (step=0002600) Train Loss: 1.8164, Train Steps/Sec: 1.32
|
| 36 |
+
[[34m2026-02-01 14:44:12[0m] (step=0002700) Train Loss: 1.8115, Train Steps/Sec: 1.32
|
| 37 |
+
[[34m2026-02-01 14:45:28[0m] (step=0002800) Train Loss: 1.8150, Train Steps/Sec: 1.32
|
| 38 |
+
[[34m2026-02-01 14:46:44[0m] (step=0002900) Train Loss: 1.8270, Train Steps/Sec: 1.32
|
| 39 |
+
[[34m2026-02-01 14:48:00[0m] (step=0003000) Train Loss: 1.9181, Train Steps/Sec: 1.32
|
| 40 |
+
[[34m2026-02-01 14:49:16[0m] (step=0003100) Train Loss: 1.9040, Train Steps/Sec: 1.32
|
| 41 |
+
[[34m2026-02-01 14:50:31[0m] (step=0003200) Train Loss: 2.2287, Train Steps/Sec: 1.32
|
| 42 |
+
[[34m2026-02-01 14:51:47[0m] (step=0003300) Train Loss: 2.0059, Train Steps/Sec: 1.32
|
| 43 |
+
[[34m2026-02-01 14:53:03[0m] (step=0003400) Train Loss: 1.8687, Train Steps/Sec: 1.32
|
| 44 |
+
[[34m2026-02-01 14:54:19[0m] (step=0003500) Train Loss: 1.9185, Train Steps/Sec: 1.32
|
| 45 |
+
[[34m2026-02-01 14:55:35[0m] (step=0003600) Train Loss: 1.9162, Train Steps/Sec: 1.32
|
| 46 |
+
[[34m2026-02-01 14:56:51[0m] (step=0003700) Train Loss: 2.0918, Train Steps/Sec: 1.32
|
| 47 |
+
[[34m2026-02-01 14:58:07[0m] (step=0003800) Train Loss: 2.5750, Train Steps/Sec: 1.32
|
| 48 |
+
[[34m2026-02-01 14:59:23[0m] (step=0003900) Train Loss: 1.8959, Train Steps/Sec: 1.32
|
| 49 |
+
[[34m2026-02-01 15:00:39[0m] (step=0004000) Train Loss: 1.8935, Train Steps/Sec: 1.32
|
| 50 |
+
[[34m2026-02-01 15:01:55[0m] (step=0004100) Train Loss: 1.8143, Train Steps/Sec: 1.32
|
| 51 |
+
[[34m2026-02-01 15:03:11[0m] (step=0004200) Train Loss: 2.0503, Train Steps/Sec: 1.32
|
| 52 |
+
[[34m2026-02-01 15:04:27[0m] (step=0004300) Train Loss: 1.8916, Train Steps/Sec: 1.32
|
| 53 |
+
[[34m2026-02-01 15:05:43[0m] (step=0004400) Train Loss: 2.1279, Train Steps/Sec: 1.32
|
| 54 |
+
[[34m2026-02-01 15:06:59[0m] (step=0004500) Train Loss: 1.8331, Train Steps/Sec: 1.32
|
| 55 |
+
[[34m2026-02-01 15:08:15[0m] (step=0004600) Train Loss: 1.8969, Train Steps/Sec: 1.32
|
| 56 |
+
[[34m2026-02-01 15:09:31[0m] (step=0004700) Train Loss: 1.8220, Train Steps/Sec: 1.32
|
| 57 |
+
[[34m2026-02-01 15:10:47[0m] (step=0004800) Train Loss: 1.8862, Train Steps/Sec: 1.32
|
| 58 |
+
[[34m2026-02-01 15:12:03[0m] (step=0004900) Train Loss: 1.9553, Train Steps/Sec: 1.32
|
| 59 |
+
[[34m2026-02-01 15:13:19[0m] (step=0005000) Train Loss: 1.8549, Train Steps/Sec: 1.31
|
| 60 |
+
[[34m2026-02-01 15:14:35[0m] (step=0005100) Train Loss: 1.9343, Train Steps/Sec: 1.32
|
| 61 |
+
[[34m2026-02-01 15:15:51[0m] (step=0005200) Train Loss: 1.9899, Train Steps/Sec: 1.32
|
| 62 |
+
[[34m2026-02-01 15:17:07[0m] (step=0005300) Train Loss: 1.9115, Train Steps/Sec: 1.32
|
| 63 |
+
[[34m2026-02-01 15:18:23[0m] (step=0005400) Train Loss: 2.2117, Train Steps/Sec: 1.32
|
| 64 |
+
[[34m2026-02-01 15:19:39[0m] (step=0005500) Train Loss: 1.9424, Train Steps/Sec: 1.32
|
| 65 |
+
[[34m2026-02-01 15:20:55[0m] (step=0005600) Train Loss: 1.8367, Train Steps/Sec: 1.32
|
| 66 |
+
[[34m2026-02-01 15:22:11[0m] (step=0005700) Train Loss: 1.8696, Train Steps/Sec: 1.32
|
| 67 |
+
[[34m2026-02-01 15:23:27[0m] (step=0005800) Train Loss: 2.2085, Train Steps/Sec: 1.32
|
| 68 |
+
[[34m2026-02-01 15:24:43[0m] (step=0005900) Train Loss: 1.8185, Train Steps/Sec: 1.32
|
| 69 |
+
[[34m2026-02-01 15:25:59[0m] (step=0006000) Train Loss: 1.8452, Train Steps/Sec: 1.32
|
| 70 |
+
[[34m2026-02-01 15:27:15[0m] (step=0006100) Train Loss: 1.8141, Train Steps/Sec: 1.32
|
| 71 |
+
[[34m2026-02-01 15:28:31[0m] (step=0006200) Train Loss: 2.4398, Train Steps/Sec: 1.32
|
| 72 |
+
[[34m2026-02-01 15:29:47[0m] (step=0006300) Train Loss: 1.9160, Train Steps/Sec: 1.32
|
| 73 |
+
[[34m2026-02-01 15:31:03[0m] (step=0006400) Train Loss: 1.9920, Train Steps/Sec: 1.32
|
| 74 |
+
[[34m2026-02-01 15:32:19[0m] (step=0006500) Train Loss: 1.8726, Train Steps/Sec: 1.32
|
| 75 |
+
[[34m2026-02-01 15:33:35[0m] (step=0006600) Train Loss: 1.9302, Train Steps/Sec: 1.32
|
| 76 |
+
[[34m2026-02-01 15:34:51[0m] (step=0006700) Train Loss: 1.8886, Train Steps/Sec: 1.32
|
| 77 |
+
[[34m2026-02-01 15:36:07[0m] (step=0006800) Train Loss: 1.8492, Train Steps/Sec: 1.32
|
| 78 |
+
[[34m2026-02-01 15:37:23[0m] (step=0006900) Train Loss: 2.0008, Train Steps/Sec: 1.32
|
| 79 |
+
[[34m2026-02-01 15:38:39[0m] (step=0007000) Train Loss: 1.9791, Train Steps/Sec: 1.32
|
| 80 |
+
[[34m2026-02-01 15:39:55[0m] (step=0007100) Train Loss: 1.9221, Train Steps/Sec: 1.32
|
| 81 |
+
[[34m2026-02-01 15:41:11[0m] (step=0007200) Train Loss: 1.8893, Train Steps/Sec: 1.32
|
| 82 |
+
[[34m2026-02-01 15:42:27[0m] (step=0007300) Train Loss: 1.8739, Train Steps/Sec: 1.32
|
| 83 |
+
[[34m2026-02-01 15:43:43[0m] (step=0007400) Train Loss: 2.6370, Train Steps/Sec: 1.32
|
| 84 |
+
[[34m2026-02-01 15:44:59[0m] (step=0007500) Train Loss: 2.1859, Train Steps/Sec: 1.32
|
| 85 |
+
[[34m2026-02-01 15:46:15[0m] (step=0007600) Train Loss: 1.8067, Train Steps/Sec: 1.32
|
| 86 |
+
[[34m2026-02-01 15:47:31[0m] (step=0007700) Train Loss: 1.8996, Train Steps/Sec: 1.32
|
| 87 |
+
[[34m2026-02-01 15:48:47[0m] (step=0007800) Train Loss: 1.9468, Train Steps/Sec: 1.32
|
| 88 |
+
[[34m2026-02-01 15:50:03[0m] (step=0007900) Train Loss: 1.8925, Train Steps/Sec: 1.32
|
| 89 |
+
[[34m2026-02-01 15:51:19[0m] (step=0008000) Train Loss: 1.7844, Train Steps/Sec: 1.32
|
| 90 |
+
[[34m2026-02-01 15:52:35[0m] (step=0008100) Train Loss: 1.9823, Train Steps/Sec: 1.32
|
| 91 |
+
[[34m2026-02-01 15:53:51[0m] (step=0008200) Train Loss: 1.9363, Train Steps/Sec: 1.32
|
| 92 |
+
[[34m2026-02-01 15:55:07[0m] (step=0008300) Train Loss: 1.8508, Train Steps/Sec: 1.32
|
| 93 |
+
[[34m2026-02-01 15:56:22[0m] (step=0008400) Train Loss: 1.9048, Train Steps/Sec: 1.32
|
| 94 |
+
[[34m2026-02-01 15:57:38[0m] (step=0008500) Train Loss: 1.8955, Train Steps/Sec: 1.32
|
| 95 |
+
[[34m2026-02-01 15:58:54[0m] (step=0008600) Train Loss: 1.8585, Train Steps/Sec: 1.32
|
| 96 |
+
[[34m2026-02-01 16:00:10[0m] (step=0008700) Train Loss: 1.8621, Train Steps/Sec: 1.32
|
| 97 |
+
[[34m2026-02-01 16:01:26[0m] (step=0008800) Train Loss: 1.8826, Train Steps/Sec: 1.32
|
| 98 |
+
[[34m2026-02-01 16:02:43[0m] (step=0008900) Train Loss: 1.9289, Train Steps/Sec: 1.31
|
| 99 |
+
[[34m2026-02-01 16:03:59[0m] (step=0009000) Train Loss: 1.9667, Train Steps/Sec: 1.32
|
| 100 |
+
[[34m2026-02-01 16:05:15[0m] (step=0009100) Train Loss: 2.1871, Train Steps/Sec: 1.32
|
| 101 |
+
[[34m2026-02-01 16:06:31[0m] (step=0009200) Train Loss: 1.8651, Train Steps/Sec: 1.32
|
| 102 |
+
[[34m2026-02-01 16:07:47[0m] (step=0009300) Train Loss: 1.9620, Train Steps/Sec: 1.32
|
| 103 |
+
[[34m2026-02-01 16:09:03[0m] (step=0009400) Train Loss: 1.8992, Train Steps/Sec: 1.32
|
| 104 |
+
[[34m2026-02-01 16:10:18[0m] (step=0009500) Train Loss: 1.8620, Train Steps/Sec: 1.32
|
| 105 |
+
[[34m2026-02-01 16:11:34[0m] (step=0009600) Train Loss: 1.9782, Train Steps/Sec: 1.32
|
| 106 |
+
[[34m2026-02-01 16:12:50[0m] (step=0009700) Train Loss: 2.3364, Train Steps/Sec: 1.32
|
| 107 |
+
[[34m2026-02-01 16:14:06[0m] (step=0009800) Train Loss: 1.8309, Train Steps/Sec: 1.32
|
| 108 |
+
[[34m2026-02-01 16:15:22[0m] (step=0009900) Train Loss: 2.5777, Train Steps/Sec: 1.32
|
| 109 |
+
[[34m2026-02-01 16:16:38[0m] (step=0010000) Train Loss: 1.9410, Train Steps/Sec: 1.32
|
| 110 |
+
[[34m2026-02-01 16:16:45[0m] Beginning epoch 1...
|
| 111 |
+
[[34m2026-02-01 16:17:56[0m] (step=0010100) Train Loss: 1.8156, Train Steps/Sec: 1.28
|
| 112 |
+
[[34m2026-02-01 16:19:12[0m] (step=0010200) Train Loss: 1.7965, Train Steps/Sec: 1.32
|
| 113 |
+
[[34m2026-02-01 16:20:28[0m] (step=0010300) Train Loss: 1.9732, Train Steps/Sec: 1.32
|
| 114 |
+
[[34m2026-02-01 16:21:44[0m] (step=0010400) Train Loss: 2.6702, Train Steps/Sec: 1.32
|
| 115 |
+
[[34m2026-02-01 16:23:00[0m] (step=0010500) Train Loss: 1.9175, Train Steps/Sec: 1.32
|
| 116 |
+
[[34m2026-02-01 16:24:16[0m] (step=0010600) Train Loss: 1.8493, Train Steps/Sec: 1.32
|
| 117 |
+
[[34m2026-02-01 16:25:32[0m] (step=0010700) Train Loss: 1.8514, Train Steps/Sec: 1.32
|
| 118 |
+
[[34m2026-02-01 16:26:48[0m] (step=0010800) Train Loss: 2.0059, Train Steps/Sec: 1.32
|
| 119 |
+
[[34m2026-02-01 16:28:04[0m] (step=0010900) Train Loss: 1.8519, Train Steps/Sec: 1.32
|
| 120 |
+
[[34m2026-02-01 16:29:20[0m] (step=0011000) Train Loss: 1.8523, Train Steps/Sec: 1.32
|
| 121 |
+
[[34m2026-02-01 16:30:36[0m] (step=0011100) Train Loss: 1.7980, Train Steps/Sec: 1.32
|
| 122 |
+
[[34m2026-02-01 16:31:52[0m] (step=0011200) Train Loss: 1.8429, Train Steps/Sec: 1.32
|
| 123 |
+
[[34m2026-02-01 16:33:08[0m] (step=0011300) Train Loss: 1.9200, Train Steps/Sec: 1.32
|
| 124 |
+
[[34m2026-02-01 16:34:24[0m] (step=0011400) Train Loss: 1.8371, Train Steps/Sec: 1.32
|
| 125 |
+
[[34m2026-02-01 16:35:40[0m] (step=0011500) Train Loss: 2.0173, Train Steps/Sec: 1.32
|
| 126 |
+
[[34m2026-02-01 16:36:56[0m] (step=0011600) Train Loss: 1.8135, Train Steps/Sec: 1.32
|
| 127 |
+
[[34m2026-02-01 16:38:12[0m] (step=0011700) Train Loss: 1.9532, Train Steps/Sec: 1.32
|
| 128 |
+
[[34m2026-02-01 16:39:28[0m] (step=0011800) Train Loss: 2.0043, Train Steps/Sec: 1.32
|
| 129 |
+
[[34m2026-02-01 16:40:44[0m] (step=0011900) Train Loss: 1.8474, Train Steps/Sec: 1.32
|
| 130 |
+
[[34m2026-02-01 16:42:00[0m] (step=0012000) Train Loss: 1.8364, Train Steps/Sec: 1.32
|
| 131 |
+
[[34m2026-02-01 16:43:15[0m] (step=0012100) Train Loss: 2.6696, Train Steps/Sec: 1.32
|
| 132 |
+
[[34m2026-02-01 16:44:31[0m] (step=0012200) Train Loss: 1.8652, Train Steps/Sec: 1.32
|
| 133 |
+
[[34m2026-02-01 16:45:47[0m] (step=0012300) Train Loss: 1.9174, Train Steps/Sec: 1.32
|
| 134 |
+
[[34m2026-02-01 16:47:03[0m] (step=0012400) Train Loss: 1.8479, Train Steps/Sec: 1.31
|
| 135 |
+
[[34m2026-02-01 16:48:19[0m] (step=0012500) Train Loss: 1.8228, Train Steps/Sec: 1.32
|
| 136 |
+
[[34m2026-02-01 16:49:35[0m] (step=0012600) Train Loss: 1.9067, Train Steps/Sec: 1.32
|
| 137 |
+
[[34m2026-02-01 16:50:51[0m] (step=0012700) Train Loss: 1.7572, Train Steps/Sec: 1.32
|
| 138 |
+
[[34m2026-02-01 16:52:07[0m] (step=0012800) Train Loss: 1.8446, Train Steps/Sec: 1.32
|
| 139 |
+
[[34m2026-02-01 16:53:23[0m] (step=0012900) Train Loss: 1.8543, Train Steps/Sec: 1.32
|
| 140 |
+
[[34m2026-02-01 16:54:39[0m] (step=0013000) Train Loss: 1.8222, Train Steps/Sec: 1.32
|
| 141 |
+
[[34m2026-02-01 16:55:55[0m] (step=0013100) Train Loss: 2.0108, Train Steps/Sec: 1.32
|
| 142 |
+
[[34m2026-02-01 16:57:11[0m] (step=0013200) Train Loss: 2.3761, Train Steps/Sec: 1.32
|
| 143 |
+
[[34m2026-02-01 16:58:27[0m] (step=0013300) Train Loss: 1.8902, Train Steps/Sec: 1.32
|
| 144 |
+
[[34m2026-02-01 16:59:43[0m] (step=0013400) Train Loss: 1.8800, Train Steps/Sec: 1.32
|
| 145 |
+
[[34m2026-02-01 17:00:59[0m] (step=0013500) Train Loss: 1.7917, Train Steps/Sec: 1.32
|
| 146 |
+
[[34m2026-02-01 17:02:15[0m] (step=0013600) Train Loss: 1.9730, Train Steps/Sec: 1.32
|
| 147 |
+
[[34m2026-02-01 17:03:31[0m] (step=0013700) Train Loss: 1.8894, Train Steps/Sec: 1.32
|
| 148 |
+
[[34m2026-02-01 17:04:47[0m] (step=0013800) Train Loss: 2.1075, Train Steps/Sec: 1.32
|
| 149 |
+
[[34m2026-02-01 17:06:03[0m] (step=0013900) Train Loss: 1.8469, Train Steps/Sec: 1.32
|
| 150 |
+
[[34m2026-02-01 17:07:19[0m] (step=0014000) Train Loss: 1.8705, Train Steps/Sec: 1.32
|
| 151 |
+
[[34m2026-02-01 17:08:35[0m] (step=0014100) Train Loss: 1.8630, Train Steps/Sec: 1.32
|
| 152 |
+
[[34m2026-02-01 17:09:51[0m] (step=0014200) Train Loss: 1.8509, Train Steps/Sec: 1.32
|
| 153 |
+
[[34m2026-02-01 17:11:07[0m] (step=0014300) Train Loss: 2.2249, Train Steps/Sec: 1.32
|
| 154 |
+
[[34m2026-02-01 17:12:23[0m] (step=0014400) Train Loss: 1.8378, Train Steps/Sec: 1.32
|
| 155 |
+
[[34m2026-02-01 17:13:39[0m] (step=0014500) Train Loss: 1.8106, Train Steps/Sec: 1.32
|
| 156 |
+
[[34m2026-02-01 17:14:55[0m] (step=0014600) Train Loss: 1.8131, Train Steps/Sec: 1.32
|
| 157 |
+
[[34m2026-02-01 17:16:11[0m] (step=0014700) Train Loss: 1.9024, Train Steps/Sec: 1.32
|
| 158 |
+
[[34m2026-02-01 17:17:27[0m] (step=0014800) Train Loss: 2.2030, Train Steps/Sec: 1.32
|
| 159 |
+
[[34m2026-02-01 17:18:42[0m] (step=0014900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 160 |
+
[[34m2026-02-01 17:19:58[0m] (step=0015000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 161 |
+
[[34m2026-02-01 17:21:13[0m] (step=0015100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 162 |
+
[[34m2026-02-01 17:22:28[0m] (step=0015200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 163 |
+
[[34m2026-02-01 17:23:43[0m] (step=0015300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 164 |
+
[[34m2026-02-01 17:24:58[0m] (step=0015400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 165 |
+
[[34m2026-02-01 17:26:13[0m] (step=0015500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 166 |
+
[[34m2026-02-01 17:27:29[0m] (step=0015600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 167 |
+
[[34m2026-02-01 17:28:44[0m] (step=0015700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 168 |
+
[[34m2026-02-01 17:29:59[0m] (step=0015800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 169 |
+
[[34m2026-02-01 17:31:14[0m] (step=0015900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 170 |
+
[[34m2026-02-01 17:32:29[0m] (step=0016000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 171 |
+
[[34m2026-02-01 17:33:44[0m] (step=0016100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 172 |
+
[[34m2026-02-01 17:34:59[0m] (step=0016200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 173 |
+
[[34m2026-02-01 17:36:14[0m] (step=0016300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 174 |
+
[[34m2026-02-01 17:37:29[0m] (step=0016400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 175 |
+
[[34m2026-02-01 17:38:45[0m] (step=0016500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 176 |
+
[[34m2026-02-01 17:40:00[0m] (step=0016600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 177 |
+
[[34m2026-02-01 17:41:15[0m] (step=0016700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 178 |
+
[[34m2026-02-01 17:42:30[0m] (step=0016800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 179 |
+
[[34m2026-02-01 17:43:45[0m] (step=0016900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 180 |
+
[[34m2026-02-01 17:45:00[0m] (step=0017000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 181 |
+
[[34m2026-02-01 17:46:16[0m] (step=0017100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 182 |
+
[[34m2026-02-01 17:47:31[0m] (step=0017200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 183 |
+
[[34m2026-02-01 17:48:46[0m] (step=0017300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 184 |
+
[[34m2026-02-01 17:50:01[0m] (step=0017400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 185 |
+
[[34m2026-02-01 17:51:16[0m] (step=0017500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 186 |
+
[[34m2026-02-01 17:52:31[0m] (step=0017600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 187 |
+
[[34m2026-02-01 17:53:46[0m] (step=0017700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 188 |
+
[[34m2026-02-01 17:55:01[0m] (step=0017800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 189 |
+
[[34m2026-02-01 17:56:16[0m] (step=0017900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 190 |
+
[[34m2026-02-01 17:57:31[0m] (step=0018000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 191 |
+
[[34m2026-02-01 17:58:46[0m] (step=0018100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 192 |
+
[[34m2026-02-01 18:00:02[0m] (step=0018200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 193 |
+
[[34m2026-02-01 18:01:17[0m] (step=0018300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 194 |
+
[[34m2026-02-01 18:02:32[0m] (step=0018400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 195 |
+
[[34m2026-02-01 18:03:47[0m] (step=0018500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 196 |
+
[[34m2026-02-01 18:05:02[0m] (step=0018600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 197 |
+
[[34m2026-02-01 18:06:17[0m] (step=0018700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 198 |
+
[[34m2026-02-01 18:07:32[0m] (step=0018800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 199 |
+
[[34m2026-02-01 18:08:47[0m] (step=0018900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 200 |
+
[[34m2026-02-01 18:10:02[0m] (step=0019000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 201 |
+
[[34m2026-02-01 18:11:17[0m] (step=0019100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 202 |
+
[[34m2026-02-01 18:12:33[0m] (step=0019200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 203 |
+
[[34m2026-02-01 18:13:48[0m] (step=0019300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 204 |
+
[[34m2026-02-01 18:15:03[0m] (step=0019400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 205 |
+
[[34m2026-02-01 18:16:18[0m] (step=0019500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 206 |
+
[[34m2026-02-01 18:17:33[0m] (step=0019600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 207 |
+
[[34m2026-02-01 18:18:48[0m] (step=0019700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 208 |
+
[[34m2026-02-01 18:20:03[0m] (step=0019800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 209 |
+
[[34m2026-02-01 18:21:18[0m] (step=0019900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 210 |
+
[[34m2026-02-01 18:22:34[0m] (step=0020000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 211 |
+
[[34m2026-02-01 18:22:47[0m] Beginning epoch 2...
|
| 212 |
+
[[34m2026-02-01 18:23:51[0m] (step=0020100) Train Loss: nan, Train Steps/Sec: 1.29
|
| 213 |
+
[[34m2026-02-01 18:25:06[0m] (step=0020200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 214 |
+
[[34m2026-02-01 18:26:21[0m] (step=0020300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 215 |
+
[[34m2026-02-01 18:27:36[0m] (step=0020400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 216 |
+
[[34m2026-02-01 18:28:51[0m] (step=0020500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 217 |
+
[[34m2026-02-01 18:30:06[0m] (step=0020600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 218 |
+
[[34m2026-02-01 18:31:21[0m] (step=0020700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 219 |
+
[[34m2026-02-01 18:32:37[0m] (step=0020800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 220 |
+
[[34m2026-02-01 18:33:52[0m] (step=0020900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 221 |
+
[[34m2026-02-01 18:35:07[0m] (step=0021000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 222 |
+
[[34m2026-02-01 18:36:22[0m] (step=0021100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 223 |
+
[[34m2026-02-01 18:37:37[0m] (step=0021200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 224 |
+
[[34m2026-02-01 18:38:52[0m] (step=0021300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 225 |
+
[[34m2026-02-01 18:40:07[0m] (step=0021400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 226 |
+
[[34m2026-02-01 18:41:22[0m] (step=0021500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 227 |
+
[[34m2026-02-01 18:42:37[0m] (step=0021600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 228 |
+
[[34m2026-02-01 18:43:53[0m] (step=0021700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 229 |
+
[[34m2026-02-01 18:45:08[0m] (step=0021800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 230 |
+
[[34m2026-02-01 18:46:23[0m] (step=0021900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 231 |
+
[[34m2026-02-01 18:47:38[0m] (step=0022000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 232 |
+
[[34m2026-02-01 18:48:53[0m] (step=0022100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 233 |
+
[[34m2026-02-01 18:50:08[0m] (step=0022200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 234 |
+
[[34m2026-02-01 18:51:24[0m] (step=0022300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 235 |
+
[[34m2026-02-01 18:52:39[0m] (step=0022400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 236 |
+
[[34m2026-02-01 18:53:54[0m] (step=0022500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 237 |
+
[[34m2026-02-01 18:55:09[0m] (step=0022600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 238 |
+
[[34m2026-02-01 18:56:24[0m] (step=0022700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 239 |
+
[[34m2026-02-01 18:57:39[0m] (step=0022800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 240 |
+
[[34m2026-02-01 18:58:54[0m] (step=0022900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 241 |
+
[[34m2026-02-01 19:00:09[0m] (step=0023000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 242 |
+
[[34m2026-02-01 19:01:24[0m] (step=0023100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 243 |
+
[[34m2026-02-01 19:02:39[0m] (step=0023200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 244 |
+
[[34m2026-02-01 19:03:54[0m] (step=0023300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 245 |
+
[[34m2026-02-01 19:05:09[0m] (step=0023400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 246 |
+
[[34m2026-02-01 19:06:25[0m] (step=0023500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 247 |
+
[[34m2026-02-01 19:07:40[0m] (step=0023600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 248 |
+
[[34m2026-02-01 19:08:55[0m] (step=0023700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 249 |
+
[[34m2026-02-01 19:10:10[0m] (step=0023800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 250 |
+
[[34m2026-02-01 19:11:25[0m] (step=0023900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 251 |
+
[[34m2026-02-01 19:12:40[0m] (step=0024000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 252 |
+
[[34m2026-02-01 19:13:56[0m] (step=0024100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 253 |
+
[[34m2026-02-01 19:15:11[0m] (step=0024200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 254 |
+
[[34m2026-02-01 19:16:26[0m] (step=0024300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 255 |
+
[[34m2026-02-01 19:17:41[0m] (step=0024400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 256 |
+
[[34m2026-02-01 19:18:56[0m] (step=0024500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 257 |
+
[[34m2026-02-01 19:20:11[0m] (step=0024600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 258 |
+
[[34m2026-02-01 19:21:26[0m] (step=0024700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 259 |
+
[[34m2026-02-01 19:22:41[0m] (step=0024800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 260 |
+
[[34m2026-02-01 19:23:56[0m] (step=0024900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 261 |
+
[[34m2026-02-01 19:25:11[0m] (step=0025000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 262 |
+
25000
|
| 263 |
+
[[34m2026-02-01 19:25:12[0m] Saved checkpoint to results_256_vp/depth-mu-2-000-SiT-XL-2-VP-velocity-None/checkpoints/0025000.pt
|
| 264 |
+
[[34m2026-02-01 19:26:27[0m] (step=0025100) Train Loss: nan, Train Steps/Sec: 1.32
|
| 265 |
+
[[34m2026-02-01 19:27:36[0m] Generating EMA samples...
|
| 266 |
+
[[34m2026-02-01 19:27:42[0m] (step=0025200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 267 |
+
[[34m2026-02-01 19:28:57[0m] (step=0025300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 268 |
+
[[34m2026-02-01 19:30:13[0m] (step=0025400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 269 |
+
[[34m2026-02-01 19:31:28[0m] (step=0025500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 270 |
+
[[34m2026-02-01 19:32:43[0m] (step=0025600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 271 |
+
[[34m2026-02-01 19:33:58[0m] (step=0025700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 272 |
+
[[34m2026-02-01 19:35:13[0m] (step=0025800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 273 |
+
[[34m2026-02-01 19:36:28[0m] (step=0025900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 274 |
+
[[34m2026-02-01 19:37:44[0m] (step=0026000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 275 |
+
[[34m2026-02-01 19:38:59[0m] (step=0026100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 276 |
+
[[34m2026-02-01 19:40:14[0m] (step=0026200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 277 |
+
[[34m2026-02-01 19:41:29[0m] (step=0026300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 278 |
+
[[34m2026-02-01 19:42:44[0m] (step=0026400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 279 |
+
[[34m2026-02-01 19:43:59[0m] (step=0026500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 280 |
+
[[34m2026-02-01 19:45:14[0m] (step=0026600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 281 |
+
[[34m2026-02-01 19:46:29[0m] (step=0026700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 282 |
+
[[34m2026-02-01 19:47:44[0m] (step=0026800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 283 |
+
[[34m2026-02-01 19:49:00[0m] (step=0026900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 284 |
+
[[34m2026-02-01 19:50:15[0m] (step=0027000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 285 |
+
[[34m2026-02-01 19:51:30[0m] (step=0027100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 286 |
+
[[34m2026-02-01 19:52:45[0m] (step=0027200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 287 |
+
[[34m2026-02-01 19:54:00[0m] (step=0027300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 288 |
+
[[34m2026-02-01 19:55:15[0m] (step=0027400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 289 |
+
[[34m2026-02-01 19:56:30[0m] (step=0027500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 290 |
+
[[34m2026-02-01 19:57:45[0m] (step=0027600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 291 |
+
[[34m2026-02-01 19:59:00[0m] (step=0027700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 292 |
+
[[34m2026-02-01 20:00:15[0m] (step=0027800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 293 |
+
[[34m2026-02-01 20:01:31[0m] (step=0027900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 294 |
+
[[34m2026-02-01 20:02:46[0m] (step=0028000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 295 |
+
[[34m2026-02-01 20:04:01[0m] (step=0028100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 296 |
+
[[34m2026-02-01 20:05:16[0m] (step=0028200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 297 |
+
[[34m2026-02-01 20:06:31[0m] (step=0028300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 298 |
+
[[34m2026-02-01 20:07:46[0m] (step=0028400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 299 |
+
[[34m2026-02-01 20:09:02[0m] (step=0028500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 300 |
+
[[34m2026-02-01 20:10:17[0m] (step=0028600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 301 |
+
[[34m2026-02-01 20:11:32[0m] (step=0028700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 302 |
+
[[34m2026-02-01 20:12:47[0m] (step=0028800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 303 |
+
[[34m2026-02-01 20:14:02[0m] (step=0028900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 304 |
+
[[34m2026-02-01 20:15:17[0m] (step=0029000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 305 |
+
[[34m2026-02-01 20:16:32[0m] (step=0029100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 306 |
+
[[34m2026-02-01 20:17:47[0m] (step=0029200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 307 |
+
[[34m2026-02-01 20:19:02[0m] (step=0029300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 308 |
+
[[34m2026-02-01 20:20:18[0m] (step=0029400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 309 |
+
[[34m2026-02-01 20:21:33[0m] (step=0029500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 310 |
+
[[34m2026-02-01 20:22:48[0m] (step=0029600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 311 |
+
[[34m2026-02-01 20:24:03[0m] (step=0029700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 312 |
+
[[34m2026-02-01 20:25:18[0m] (step=0029800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 313 |
+
[[34m2026-02-01 20:26:33[0m] (step=0029900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 314 |
+
[[34m2026-02-01 20:27:49[0m] (step=0030000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 315 |
+
[[34m2026-02-01 20:28:09[0m] Beginning epoch 3...
|
| 316 |
+
[[34m2026-02-01 20:29:06[0m] (step=0030100) Train Loss: nan, Train Steps/Sec: 1.29
|
| 317 |
+
[[34m2026-02-01 20:30:21[0m] (step=0030200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 318 |
+
[[34m2026-02-01 20:31:36[0m] (step=0030300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 319 |
+
[[34m2026-02-01 20:32:51[0m] (step=0030400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 320 |
+
[[34m2026-02-01 20:34:06[0m] (step=0030500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 321 |
+
[[34m2026-02-01 20:35:22[0m] (step=0030600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 322 |
+
[[34m2026-02-01 20:36:37[0m] (step=0030700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 323 |
+
[[34m2026-02-01 20:37:52[0m] (step=0030800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 324 |
+
[[34m2026-02-01 20:39:07[0m] (step=0030900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 325 |
+
[[34m2026-02-01 20:40:22[0m] (step=0031000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 326 |
+
[[34m2026-02-01 20:41:37[0m] (step=0031100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 327 |
+
[[34m2026-02-01 20:42:52[0m] (step=0031200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 328 |
+
[[34m2026-02-01 20:44:08[0m] (step=0031300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 329 |
+
[[34m2026-02-01 20:45:23[0m] (step=0031400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 330 |
+
[[34m2026-02-01 20:46:38[0m] (step=0031500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 331 |
+
[[34m2026-02-01 20:47:53[0m] (step=0031600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 332 |
+
[[34m2026-02-01 20:49:08[0m] (step=0031700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 333 |
+
[[34m2026-02-01 20:50:23[0m] (step=0031800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 334 |
+
[[34m2026-02-01 20:51:38[0m] (step=0031900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 335 |
+
[[34m2026-02-01 20:52:54[0m] (step=0032000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 336 |
+
[[34m2026-02-01 20:54:09[0m] (step=0032100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 337 |
+
[[34m2026-02-01 20:55:24[0m] (step=0032200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 338 |
+
[[34m2026-02-01 20:56:39[0m] (step=0032300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 339 |
+
[[34m2026-02-01 20:57:54[0m] (step=0032400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 340 |
+
[[34m2026-02-01 20:59:09[0m] (step=0032500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 341 |
+
[[34m2026-02-01 21:00:24[0m] (step=0032600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 342 |
+
[[34m2026-02-01 21:01:39[0m] (step=0032700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 343 |
+
[[34m2026-02-01 21:02:54[0m] (step=0032800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 344 |
+
[[34m2026-02-01 21:04:09[0m] (step=0032900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 345 |
+
[[34m2026-02-01 21:05:25[0m] (step=0033000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 346 |
+
[[34m2026-02-01 21:06:40[0m] (step=0033100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 347 |
+
[[34m2026-02-01 21:07:55[0m] (step=0033200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 348 |
+
[[34m2026-02-01 21:09:10[0m] (step=0033300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 349 |
+
[[34m2026-02-01 21:10:25[0m] (step=0033400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 350 |
+
[[34m2026-02-01 21:11:40[0m] (step=0033500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 351 |
+
[[34m2026-02-01 21:12:56[0m] (step=0033600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 352 |
+
[[34m2026-02-01 21:14:11[0m] (step=0033700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 353 |
+
[[34m2026-02-01 21:15:26[0m] (step=0033800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 354 |
+
[[34m2026-02-01 21:16:41[0m] (step=0033900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 355 |
+
[[34m2026-02-01 21:17:56[0m] (step=0034000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 356 |
+
[[34m2026-02-01 21:19:11[0m] (step=0034100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 357 |
+
[[34m2026-02-01 21:20:27[0m] (step=0034200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 358 |
+
[[34m2026-02-01 21:21:42[0m] (step=0034300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 359 |
+
[[34m2026-02-01 21:22:57[0m] (step=0034400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 360 |
+
[[34m2026-02-01 21:24:12[0m] (step=0034500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 361 |
+
[[34m2026-02-01 21:25:27[0m] (step=0034600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 362 |
+
[[34m2026-02-01 21:26:42[0m] (step=0034700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 363 |
+
[[34m2026-02-01 21:27:57[0m] (step=0034800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 364 |
+
[[34m2026-02-01 21:29:12[0m] (step=0034900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 365 |
+
[[34m2026-02-01 21:30:28[0m] (step=0035000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 366 |
+
[[34m2026-02-01 21:31:43[0m] (step=0035100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 367 |
+
[[34m2026-02-01 21:32:58[0m] (step=0035200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 368 |
+
[[34m2026-02-01 21:34:13[0m] (step=0035300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 369 |
+
[[34m2026-02-01 21:35:28[0m] (step=0035400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 370 |
+
[[34m2026-02-01 21:36:43[0m] (step=0035500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 371 |
+
[[34m2026-02-01 21:37:58[0m] (step=0035600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 372 |
+
[[34m2026-02-01 21:39:13[0m] (step=0035700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 373 |
+
[[34m2026-02-01 21:40:29[0m] (step=0035800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 374 |
+
[[34m2026-02-01 21:41:44[0m] (step=0035900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 375 |
+
[[34m2026-02-01 21:42:59[0m] (step=0036000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 376 |
+
[[34m2026-02-01 21:44:14[0m] (step=0036100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 377 |
+
[[34m2026-02-01 21:45:29[0m] (step=0036200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 378 |
+
[[34m2026-02-01 21:46:44[0m] (step=0036300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 379 |
+
[[34m2026-02-01 21:48:00[0m] (step=0036400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 380 |
+
[[34m2026-02-01 21:49:15[0m] (step=0036500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 381 |
+
[[34m2026-02-01 21:50:30[0m] (step=0036600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 382 |
+
[[34m2026-02-01 21:51:45[0m] (step=0036700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 383 |
+
[[34m2026-02-01 21:53:00[0m] (step=0036800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 384 |
+
[[34m2026-02-01 21:54:15[0m] (step=0036900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 385 |
+
[[34m2026-02-01 21:55:31[0m] (step=0037000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 386 |
+
[[34m2026-02-01 21:56:46[0m] (step=0037100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 387 |
+
[[34m2026-02-01 21:58:01[0m] (step=0037200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 388 |
+
[[34m2026-02-01 21:59:16[0m] (step=0037300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 389 |
+
[[34m2026-02-01 22:00:31[0m] (step=0037400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 390 |
+
[[34m2026-02-01 22:01:46[0m] (step=0037500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 391 |
+
[[34m2026-02-01 22:03:01[0m] (step=0037600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 392 |
+
[[34m2026-02-01 22:04:17[0m] (step=0037700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 393 |
+
[[34m2026-02-01 22:05:32[0m] (step=0037800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 394 |
+
[[34m2026-02-01 22:06:47[0m] (step=0037900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 395 |
+
[[34m2026-02-01 22:08:02[0m] (step=0038000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 396 |
+
[[34m2026-02-01 22:09:17[0m] (step=0038100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 397 |
+
[[34m2026-02-01 22:10:32[0m] (step=0038200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 398 |
+
[[34m2026-02-01 22:11:47[0m] (step=0038300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 399 |
+
[[34m2026-02-01 22:13:02[0m] (step=0038400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 400 |
+
[[34m2026-02-01 22:14:18[0m] (step=0038500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 401 |
+
[[34m2026-02-01 22:15:33[0m] (step=0038600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 402 |
+
[[34m2026-02-01 22:16:48[0m] (step=0038700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 403 |
+
[[34m2026-02-01 22:18:03[0m] (step=0038800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 404 |
+
[[34m2026-02-01 22:19:18[0m] (step=0038900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 405 |
+
[[34m2026-02-01 22:20:33[0m] (step=0039000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 406 |
+
[[34m2026-02-01 22:21:48[0m] (step=0039100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 407 |
+
[[34m2026-02-01 22:23:04[0m] (step=0039200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 408 |
+
[[34m2026-02-01 22:24:19[0m] (step=0039300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 409 |
+
[[34m2026-02-01 22:25:34[0m] (step=0039400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 410 |
+
[[34m2026-02-01 22:26:49[0m] (step=0039500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 411 |
+
[[34m2026-02-01 22:28:04[0m] (step=0039600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 412 |
+
[[34m2026-02-01 22:29:19[0m] (step=0039700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 413 |
+
[[34m2026-02-01 22:30:34[0m] (step=0039800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 414 |
+
[[34m2026-02-01 22:31:49[0m] (step=0039900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 415 |
+
[[34m2026-02-01 22:33:04[0m] (step=0040000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 416 |
+
[[34m2026-02-01 22:33:32[0m] Beginning epoch 4...
|
| 417 |
+
[[34m2026-02-01 22:34:22[0m] (step=0040100) Train Loss: nan, Train Steps/Sec: 1.29
|
| 418 |
+
[[34m2026-02-01 22:35:37[0m] (step=0040200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 419 |
+
[[34m2026-02-01 22:36:52[0m] (step=0040300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 420 |
+
[[34m2026-02-01 22:38:07[0m] (step=0040400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 421 |
+
[[34m2026-02-01 22:39:22[0m] (step=0040500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 422 |
+
[[34m2026-02-01 22:40:37[0m] (step=0040600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 423 |
+
[[34m2026-02-01 22:41:52[0m] (step=0040700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 424 |
+
[[34m2026-02-01 22:43:08[0m] (step=0040800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 425 |
+
[[34m2026-02-01 22:44:23[0m] (step=0040900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 426 |
+
[[34m2026-02-01 22:45:38[0m] (step=0041000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 427 |
+
[[34m2026-02-01 22:46:53[0m] (step=0041100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 428 |
+
[[34m2026-02-01 22:48:08[0m] (step=0041200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 429 |
+
[[34m2026-02-01 22:49:23[0m] (step=0041300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 430 |
+
[[34m2026-02-01 22:50:39[0m] (step=0041400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 431 |
+
[[34m2026-02-01 22:51:54[0m] (step=0041500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 432 |
+
[[34m2026-02-01 22:53:09[0m] (step=0041600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 433 |
+
[[34m2026-02-01 22:54:24[0m] (step=0041700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 434 |
+
[[34m2026-02-01 22:55:39[0m] (step=0041800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 435 |
+
[[34m2026-02-01 22:56:54[0m] (step=0041900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 436 |
+
[[34m2026-02-01 22:58:09[0m] (step=0042000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 437 |
+
[[34m2026-02-01 22:59:24[0m] (step=0042100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 438 |
+
[[34m2026-02-01 23:00:40[0m] (step=0042200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 439 |
+
[[34m2026-02-01 23:01:55[0m] (step=0042300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 440 |
+
[[34m2026-02-01 23:03:10[0m] (step=0042400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 441 |
+
[[34m2026-02-01 23:04:25[0m] (step=0042500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 442 |
+
[[34m2026-02-01 23:05:40[0m] (step=0042600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 443 |
+
[[34m2026-02-01 23:06:55[0m] (step=0042700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 444 |
+
[[34m2026-02-01 23:08:10[0m] (step=0042800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 445 |
+
[[34m2026-02-01 23:09:26[0m] (step=0042900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 446 |
+
[[34m2026-02-01 23:10:41[0m] (step=0043000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 447 |
+
[[34m2026-02-01 23:11:56[0m] (step=0043100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 448 |
+
[[34m2026-02-01 23:13:11[0m] (step=0043200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 449 |
+
[[34m2026-02-01 23:14:26[0m] (step=0043300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 450 |
+
[[34m2026-02-01 23:15:41[0m] (step=0043400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 451 |
+
[[34m2026-02-01 23:16:56[0m] (step=0043500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 452 |
+
[[34m2026-02-01 23:18:11[0m] (step=0043600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 453 |
+
[[34m2026-02-01 23:19:27[0m] (step=0043700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 454 |
+
[[34m2026-02-01 23:20:42[0m] (step=0043800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 455 |
+
[[34m2026-02-01 23:21:57[0m] (step=0043900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 456 |
+
[[34m2026-02-01 23:23:12[0m] (step=0044000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 457 |
+
[[34m2026-02-01 23:24:27[0m] (step=0044100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 458 |
+
[[34m2026-02-01 23:25:42[0m] (step=0044200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 459 |
+
[[34m2026-02-01 23:26:57[0m] (step=0044300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 460 |
+
[[34m2026-02-01 23:28:12[0m] (step=0044400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 461 |
+
[[34m2026-02-01 23:29:28[0m] (step=0044500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 462 |
+
[[34m2026-02-01 23:30:43[0m] (step=0044600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 463 |
+
[[34m2026-02-01 23:31:58[0m] (step=0044700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 464 |
+
[[34m2026-02-01 23:33:13[0m] (step=0044800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 465 |
+
[[34m2026-02-01 23:34:28[0m] (step=0044900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 466 |
+
[[34m2026-02-01 23:35:43[0m] (step=0045000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 467 |
+
[[34m2026-02-01 23:36:58[0m] (step=0045100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 468 |
+
[[34m2026-02-01 23:38:13[0m] (step=0045200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 469 |
+
[[34m2026-02-01 23:39:28[0m] (step=0045300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 470 |
+
[[34m2026-02-01 23:40:44[0m] (step=0045400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 471 |
+
[[34m2026-02-01 23:41:59[0m] (step=0045500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 472 |
+
[[34m2026-02-01 23:43:14[0m] (step=0045600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 473 |
+
[[34m2026-02-01 23:44:29[0m] (step=0045700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 474 |
+
[[34m2026-02-01 23:45:44[0m] (step=0045800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 475 |
+
[[34m2026-02-01 23:46:59[0m] (step=0045900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 476 |
+
[[34m2026-02-01 23:48:14[0m] (step=0046000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 477 |
+
[[34m2026-02-01 23:49:29[0m] (step=0046100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 478 |
+
[[34m2026-02-01 23:50:45[0m] (step=0046200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 479 |
+
[[34m2026-02-01 23:52:00[0m] (step=0046300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 480 |
+
[[34m2026-02-01 23:53:15[0m] (step=0046400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 481 |
+
[[34m2026-02-01 23:54:30[0m] (step=0046500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 482 |
+
[[34m2026-02-01 23:55:45[0m] (step=0046600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 483 |
+
[[34m2026-02-01 23:57:00[0m] (step=0046700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 484 |
+
[[34m2026-02-01 23:58:15[0m] (step=0046800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 485 |
+
[[34m2026-02-01 23:59:31[0m] (step=0046900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 486 |
+
[[34m2026-02-02 00:00:46[0m] (step=0047000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 487 |
+
[[34m2026-02-02 00:02:01[0m] (step=0047100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 488 |
+
[[34m2026-02-02 00:03:16[0m] (step=0047200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 489 |
+
[[34m2026-02-02 00:04:31[0m] (step=0047300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 490 |
+
[[34m2026-02-02 00:05:46[0m] (step=0047400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 491 |
+
[[34m2026-02-02 00:07:01[0m] (step=0047500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 492 |
+
[[34m2026-02-02 00:08:16[0m] (step=0047600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 493 |
+
[[34m2026-02-02 00:09:32[0m] (step=0047700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 494 |
+
[[34m2026-02-02 00:10:47[0m] (step=0047800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 495 |
+
[[34m2026-02-02 00:12:02[0m] (step=0047900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 496 |
+
[[34m2026-02-02 00:13:17[0m] (step=0048000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 497 |
+
[[34m2026-02-02 00:14:32[0m] (step=0048100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 498 |
+
[[34m2026-02-02 00:15:47[0m] (step=0048200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 499 |
+
[[34m2026-02-02 00:17:02[0m] (step=0048300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 500 |
+
[[34m2026-02-02 00:18:17[0m] (step=0048400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 501 |
+
[[34m2026-02-02 00:19:32[0m] (step=0048500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 502 |
+
[[34m2026-02-02 00:20:48[0m] (step=0048600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 503 |
+
[[34m2026-02-02 00:22:03[0m] (step=0048700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 504 |
+
[[34m2026-02-02 00:23:18[0m] (step=0048800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 505 |
+
[[34m2026-02-02 00:24:33[0m] (step=0048900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 506 |
+
[[34m2026-02-02 00:25:48[0m] (step=0049000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 507 |
+
[[34m2026-02-02 00:27:03[0m] (step=0049100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 508 |
+
[[34m2026-02-02 00:28:18[0m] (step=0049200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 509 |
+
[[34m2026-02-02 00:29:33[0m] (step=0049300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 510 |
+
[[34m2026-02-02 00:30:48[0m] (step=0049400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 511 |
+
[[34m2026-02-02 00:32:04[0m] (step=0049500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 512 |
+
[[34m2026-02-02 00:33:19[0m] (step=0049600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 513 |
+
[[34m2026-02-02 00:34:34[0m] (step=0049700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 514 |
+
[[34m2026-02-02 00:35:49[0m] (step=0049800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 515 |
+
[[34m2026-02-02 00:37:04[0m] (step=0049900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 516 |
+
[[34m2026-02-02 00:38:19[0m] (step=0050000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 517 |
+
50000
|
| 518 |
+
[[34m2026-02-02 00:38:20[0m] Saved checkpoint to results_256_vp/depth-mu-2-000-SiT-XL-2-VP-velocity-None/checkpoints/0050000.pt
|
| 519 |
+
[[34m2026-02-02 00:38:54[0m] Beginning epoch 5...
|
| 520 |
+
[[34m2026-02-02 00:39:37[0m] (step=0050100) Train Loss: nan, Train Steps/Sec: 1.28
|
| 521 |
+
[[34m2026-02-02 00:40:53[0m] (step=0050200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 522 |
+
[[34m2026-02-02 00:42:08[0m] (step=0050300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 523 |
+
[[34m2026-02-02 00:43:11[0m] Generating EMA samples...
|
| 524 |
+
[[34m2026-02-02 00:43:23[0m] (step=0050400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 525 |
+
[[34m2026-02-02 00:44:38[0m] (step=0050500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 526 |
+
[[34m2026-02-02 00:45:53[0m] (step=0050600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 527 |
+
[[34m2026-02-02 00:47:08[0m] (step=0050700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 528 |
+
[[34m2026-02-02 00:48:23[0m] (step=0050800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 529 |
+
[[34m2026-02-02 00:49:38[0m] (step=0050900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 530 |
+
[[34m2026-02-02 00:50:53[0m] (step=0051000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 531 |
+
[[34m2026-02-02 00:52:09[0m] (step=0051100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 532 |
+
[[34m2026-02-02 00:53:24[0m] (step=0051200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 533 |
+
[[34m2026-02-02 00:54:39[0m] (step=0051300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 534 |
+
[[34m2026-02-02 00:55:54[0m] (step=0051400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 535 |
+
[[34m2026-02-02 00:57:09[0m] (step=0051500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 536 |
+
[[34m2026-02-02 00:58:24[0m] (step=0051600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 537 |
+
[[34m2026-02-02 00:59:39[0m] (step=0051700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 538 |
+
[[34m2026-02-02 01:00:54[0m] (step=0051800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 539 |
+
[[34m2026-02-02 01:02:10[0m] (step=0051900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 540 |
+
[[34m2026-02-02 01:03:25[0m] (step=0052000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 541 |
+
[[34m2026-02-02 01:04:40[0m] (step=0052100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 542 |
+
[[34m2026-02-02 01:05:55[0m] (step=0052200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 543 |
+
[[34m2026-02-02 01:07:10[0m] (step=0052300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 544 |
+
[[34m2026-02-02 01:08:25[0m] (step=0052400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 545 |
+
[[34m2026-02-02 01:09:41[0m] (step=0052500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 546 |
+
[[34m2026-02-02 01:10:56[0m] (step=0052600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 547 |
+
[[34m2026-02-02 01:12:11[0m] (step=0052700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 548 |
+
[[34m2026-02-02 01:13:26[0m] (step=0052800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 549 |
+
[[34m2026-02-02 01:14:41[0m] (step=0052900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 550 |
+
[[34m2026-02-02 01:15:57[0m] (step=0053000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 551 |
+
[[34m2026-02-02 01:17:12[0m] (step=0053100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 552 |
+
[[34m2026-02-02 01:18:27[0m] (step=0053200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 553 |
+
[[34m2026-02-02 01:19:42[0m] (step=0053300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 554 |
+
[[34m2026-02-02 01:20:57[0m] (step=0053400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 555 |
+
[[34m2026-02-02 01:22:13[0m] (step=0053500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 556 |
+
[[34m2026-02-02 01:23:28[0m] (step=0053600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 557 |
+
[[34m2026-02-02 01:24:43[0m] (step=0053700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 558 |
+
[[34m2026-02-02 01:25:58[0m] (step=0053800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 559 |
+
[[34m2026-02-02 01:27:13[0m] (step=0053900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 560 |
+
[[34m2026-02-02 01:28:28[0m] (step=0054000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 561 |
+
[[34m2026-02-02 01:29:44[0m] (step=0054100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 562 |
+
[[34m2026-02-02 01:30:59[0m] (step=0054200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 563 |
+
[[34m2026-02-02 01:32:14[0m] (step=0054300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 564 |
+
[[34m2026-02-02 01:33:29[0m] (step=0054400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 565 |
+
[[34m2026-02-02 01:34:44[0m] (step=0054500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 566 |
+
[[34m2026-02-02 01:35:59[0m] (step=0054600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 567 |
+
[[34m2026-02-02 01:37:15[0m] (step=0054700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 568 |
+
[[34m2026-02-02 01:38:30[0m] (step=0054800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 569 |
+
[[34m2026-02-02 01:39:45[0m] (step=0054900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 570 |
+
[[34m2026-02-02 01:41:00[0m] (step=0055000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 571 |
+
[[34m2026-02-02 01:42:15[0m] (step=0055100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 572 |
+
[[34m2026-02-02 01:43:30[0m] (step=0055200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 573 |
+
[[34m2026-02-02 01:44:46[0m] (step=0055300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 574 |
+
[[34m2026-02-02 01:46:01[0m] (step=0055400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 575 |
+
[[34m2026-02-02 01:47:16[0m] (step=0055500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 576 |
+
[[34m2026-02-02 01:48:31[0m] (step=0055600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 577 |
+
[[34m2026-02-02 01:49:46[0m] (step=0055700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 578 |
+
[[34m2026-02-02 01:51:02[0m] (step=0055800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 579 |
+
[[34m2026-02-02 01:52:17[0m] (step=0055900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 580 |
+
[[34m2026-02-02 01:53:32[0m] (step=0056000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 581 |
+
[[34m2026-02-02 01:54:47[0m] (step=0056100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 582 |
+
[[34m2026-02-02 01:56:02[0m] (step=0056200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 583 |
+
[[34m2026-02-02 01:57:17[0m] (step=0056300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 584 |
+
[[34m2026-02-02 01:58:32[0m] (step=0056400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 585 |
+
[[34m2026-02-02 01:59:48[0m] (step=0056500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 586 |
+
[[34m2026-02-02 02:01:03[0m] (step=0056600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 587 |
+
[[34m2026-02-02 02:02:18[0m] (step=0056700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 588 |
+
[[34m2026-02-02 02:03:33[0m] (step=0056800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 589 |
+
[[34m2026-02-02 02:04:48[0m] (step=0056900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 590 |
+
[[34m2026-02-02 02:06:04[0m] (step=0057000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 591 |
+
[[34m2026-02-02 02:07:19[0m] (step=0057100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 592 |
+
[[34m2026-02-02 02:08:34[0m] (step=0057200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 593 |
+
[[34m2026-02-02 02:09:49[0m] (step=0057300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 594 |
+
[[34m2026-02-02 02:11:04[0m] (step=0057400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 595 |
+
[[34m2026-02-02 02:12:19[0m] (step=0057500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 596 |
+
[[34m2026-02-02 02:13:35[0m] (step=0057600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 597 |
+
[[34m2026-02-02 02:14:50[0m] (step=0057700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 598 |
+
[[34m2026-02-02 02:16:05[0m] (step=0057800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 599 |
+
[[34m2026-02-02 02:17:20[0m] (step=0057900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 600 |
+
[[34m2026-02-02 02:18:35[0m] (step=0058000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 601 |
+
[[34m2026-02-02 02:19:50[0m] (step=0058100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 602 |
+
[[34m2026-02-02 02:21:05[0m] (step=0058200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 603 |
+
[[34m2026-02-02 02:22:21[0m] (step=0058300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 604 |
+
[[34m2026-02-02 02:23:36[0m] (step=0058400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 605 |
+
[[34m2026-02-02 02:24:51[0m] (step=0058500) Train Loss: nan, Train Steps/Sec: 1.33
|
| 606 |
+
[[34m2026-02-02 02:26:06[0m] (step=0058600) Train Loss: nan, Train Steps/Sec: 1.33
|
| 607 |
+
[[34m2026-02-02 02:27:21[0m] (step=0058700) Train Loss: nan, Train Steps/Sec: 1.33
|
| 608 |
+
[[34m2026-02-02 02:28:36[0m] (step=0058800) Train Loss: nan, Train Steps/Sec: 1.33
|
| 609 |
+
[[34m2026-02-02 02:29:51[0m] (step=0058900) Train Loss: nan, Train Steps/Sec: 1.33
|
| 610 |
+
[[34m2026-02-02 02:31:06[0m] (step=0059000) Train Loss: nan, Train Steps/Sec: 1.33
|
| 611 |
+
[[34m2026-02-02 02:32:21[0m] (step=0059100) Train Loss: nan, Train Steps/Sec: 1.33
|
| 612 |
+
[[34m2026-02-02 02:33:36[0m] (step=0059200) Train Loss: nan, Train Steps/Sec: 1.33
|
| 613 |
+
[[34m2026-02-02 02:34:51[0m] (step=0059300) Train Loss: nan, Train Steps/Sec: 1.33
|
| 614 |
+
[[34m2026-02-02 02:36:07[0m] (step=0059400) Train Loss: nan, Train Steps/Sec: 1.33
|
| 615 |
+
[[34m2026-02-02 02:38:31[0m] (step=0059500) Train Loss: nan, Train Steps/Sec: 0.69
|
| 616 |
+
[[34m2026-02-02 02:41:15[0m] (step=0059600) Train Loss: nan, Train Steps/Sec: 0.61
|
| 617 |
+
[[34m2026-02-02 02:43:58[0m] (step=0059700) Train Loss: nan, Train Steps/Sec: 0.61
|
| 618 |
+
[[34m2026-02-02 02:46:41[0m] (step=0059800) Train Loss: nan, Train Steps/Sec: 0.61
|
| 619 |
+
[[34m2026-02-02 02:49:24[0m] (step=0059900) Train Loss: nan, Train Steps/Sec: 0.61
|
| 620 |
+
[[34m2026-02-02 02:52:09[0m] (step=0060000) Train Loss: nan, Train Steps/Sec: 0.61
|
| 621 |
+
[[34m2026-02-02 02:53:37[0m] Beginning epoch 6...
|
| 622 |
+
[[34m2026-02-02 02:54:54[0m] (step=0060100) Train Loss: nan, Train Steps/Sec: 0.61
|
| 623 |
+
[[34m2026-02-02 02:57:37[0m] (step=0060200) Train Loss: nan, Train Steps/Sec: 0.61
|
| 624 |
+
[[34m2026-02-02 03:00:21[0m] (step=0060300) Train Loss: nan, Train Steps/Sec: 0.61
|
| 625 |
+
[[34m2026-02-02 03:03:04[0m] (step=0060400) Train Loss: nan, Train Steps/Sec: 0.61
|
| 626 |
+
[[34m2026-02-02 03:05:46[0m] (step=0060500) Train Loss: nan, Train Steps/Sec: 0.62
|
| 627 |
+
[[34m2026-02-02 03:08:31[0m] (step=0060600) Train Loss: nan, Train Steps/Sec: 0.61
|
| 628 |
+
[[34m2026-02-02 03:11:14[0m] (step=0060700) Train Loss: nan, Train Steps/Sec: 0.61
|
Rectified_Noise/GVP-Disp/权重类型分析.md
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 损失函数权重类型分析
|
| 2 |
+
|
| 3 |
+
## 代码位置
|
| 4 |
+
`transport/transport.py` 第150-156行
|
| 5 |
+
|
| 6 |
+
## 三种权重类型
|
| 7 |
+
|
| 8 |
+
### 1. WeightType.NONE
|
| 9 |
+
```python
|
| 10 |
+
weight = 1
|
| 11 |
+
```
|
| 12 |
+
**特点:**
|
| 13 |
+
- 均匀权重,所有时间步的损失贡献相同
|
| 14 |
+
- 最简单的权重策略
|
| 15 |
+
|
| 16 |
+
**影响:**
|
| 17 |
+
- 对训练过程的影响:所有时间步 t 的损失被同等对待
|
| 18 |
+
- 优点:实现简单,训练稳定
|
| 19 |
+
- 缺点:可能忽略不同时间步的重要性差异
|
| 20 |
+
|
| 21 |
+
### 2. WeightType.VELOCITY
|
| 22 |
+
```python
|
| 23 |
+
weight = (drift_var / sigma_t) ** 2
|
| 24 |
+
```
|
| 25 |
+
**特点:**
|
| 26 |
+
- 权重与 `(drift_var / sigma_t)²` 成正比
|
| 27 |
+
- `drift_var` 是扩散系数(diffusion coefficient)
|
| 28 |
+
- `sigma_t` 是噪声系数(noise coefficient)
|
| 29 |
+
|
| 30 |
+
**数学含义:**
|
| 31 |
+
- 对于线性路径(ICPlan):`sigma_t = 1 - t`,`drift_var` 是扩散项
|
| 32 |
+
- 权重 = `(扩散系数 / 噪声系数)²`
|
| 33 |
+
- 当 `sigma_t` 较小时(接近 t=1,噪声少),权重较大
|
| 34 |
+
- 当 `sigma_t` 较大时(接近 t=0,噪声多),权重较小
|
| 35 |
+
|
| 36 |
+
**影响:**
|
| 37 |
+
- **强调后期时间步**:在去噪过程的后期(t接近1,噪声少)给予更高权重
|
| 38 |
+
- **训练重点**:模型在低噪声区域的预测精度更重要
|
| 39 |
+
- **适用场景**:当最终生成质量(低噪声区域)是关键时
|
| 40 |
+
|
| 41 |
+
**时间依赖行为:**
|
| 42 |
+
- t → 0(高噪声):`sigma_t` 大 → 权重小
|
| 43 |
+
- t → 1(低噪声):`sigma_t` 小 → 权重大
|
| 44 |
+
|
| 45 |
+
### 3. WeightType.LIKELIHOOD
|
| 46 |
+
```python
|
| 47 |
+
weight = drift_var / (sigma_t ** 2)
|
| 48 |
+
```
|
| 49 |
+
**特点:**
|
| 50 |
+
- 权重与 `drift_var / sigma_t²` 成正比
|
| 51 |
+
- 相比 VELOCITY 权重,分母是 `sigma_t²` 而不是 `sigma_t`
|
| 52 |
+
|
| 53 |
+
**数学含义:**
|
| 54 |
+
- 权重 = `扩散系数 / 噪声系数²`
|
| 55 |
+
- 当 `sigma_t` 较小时,`sigma_t²` 更小,权重更大
|
| 56 |
+
- 当 `sigma_t` 较大时,`sigma_t²` 更大,权重更小
|
| 57 |
+
|
| 58 |
+
**影响:**
|
| 59 |
+
- **更强烈地强调后期时间步**:相比 VELOCITY,对低噪声区域的权重更大
|
| 60 |
+
- **训练重点**:极大化模型在低噪声区域的预测精度
|
| 61 |
+
- **适用场景**:当需要最大化似然或生成质量时
|
| 62 |
+
|
| 63 |
+
**与 VELOCITY 的对比:**
|
| 64 |
+
- LIKELIHOOD 权重 = VELOCITY 权重 × `(1 / sigma_t)`
|
| 65 |
+
- 在相同 `drift_var` 和 `sigma_t` 下,LIKELIHOOD 权重总是 ≥ VELOCITY 权重
|
| 66 |
+
- LIKELIHOOD 对后期时间步的强调更极端
|
| 67 |
+
|
| 68 |
+
## 权重随时间的典型行为(线性路径示例)
|
| 69 |
+
|
| 70 |
+
假设线性路径(ICPlan):
|
| 71 |
+
- `sigma_t = 1 - t`(从 1 到 0)
|
| 72 |
+
- `drift_var` 通常与 `t` 相关
|
| 73 |
+
|
| 74 |
+
### 时间步 t=0.1(高噪声)
|
| 75 |
+
- `sigma_t ≈ 0.9`
|
| 76 |
+
- **NONE**: weight = 1
|
| 77 |
+
- **VELOCITY**: weight = `(drift_var / 0.9)²` ≈ 中等
|
| 78 |
+
- **LIKELIHOOD**: weight = `drift_var / 0.81` ≈ 较大
|
| 79 |
+
|
| 80 |
+
### 时间步 t=0.5(中等噪声)
|
| 81 |
+
- `sigma_t = 0.5`
|
| 82 |
+
- **NONE**: weight = 1
|
| 83 |
+
- **VELOCITY**: weight = `(drift_var / 0.5)²` = `4 × drift_var²`
|
| 84 |
+
- **LIKELIHOOD**: weight = `drift_var / 0.25` = `4 × drift_var`
|
| 85 |
+
|
| 86 |
+
### 时间步 t=0.9(低噪声)
|
| 87 |
+
- `sigma_t ≈ 0.1`
|
| 88 |
+
- **NONE**: weight = 1
|
| 89 |
+
- **VELOCITY**: weight = `(drift_var / 0.1)²` = `100 × drift_var²`(很大)
|
| 90 |
+
- **LIKELIHOOD**: weight = `drift_var / 0.01` = `100 × drift_var`(非常大)
|
| 91 |
+
|
| 92 |
+
## 实际损失计算
|
| 93 |
+
|
| 94 |
+
### 对于 NOISE 模型类型:
|
| 95 |
+
```python
|
| 96 |
+
loss = mean_flat(weight * ((model_output - x0) ** 2))
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
### 对于 SCORE 模型类型:
|
| 100 |
+
```python
|
| 101 |
+
loss = mean_flat(weight * ((model_output * sigma_t + x0) ** 2))
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
## 选择建议
|
| 105 |
+
|
| 106 |
+
1. **WeightType.NONE**
|
| 107 |
+
- 适合:简单实验、基线对比
|
| 108 |
+
- 优点:训练稳定,实现简单
|
| 109 |
+
- 缺点:可能忽略时间步重要性
|
| 110 |
+
|
| 111 |
+
2. **WeightType.VELOCITY**
|
| 112 |
+
- 适合:关注最终生成质量
|
| 113 |
+
- 优点:强调低噪声区域,生成质量通常更好
|
| 114 |
+
- 缺点:可能在高噪声区域训练不足
|
| 115 |
+
|
| 116 |
+
3. **WeightType.LIKELIHOOD**
|
| 117 |
+
- 适合:需要最大化似然、追求最高生成质量
|
| 118 |
+
- 优点:最强调低噪声区域
|
| 119 |
+
- 缺点:可能在高噪声区域训练严重不足,训练可能不稳定
|
| 120 |
+
|
| 121 |
+
## 总结
|
| 122 |
+
|
| 123 |
+
三种权重类型形成了一个从均匀到极端强调后期的梯度:
|
| 124 |
+
|
| 125 |
+
```
|
| 126 |
+
NONE (均匀) < VELOCITY (强调后期) < LIKELIHOOD (极端强调后期)
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
选择哪种权重取决于:
|
| 130 |
+
- 训练目标(生成质量 vs 训练稳定性)
|
| 131 |
+
- 数据特性
|
| 132 |
+
- 模型类型(NOISE vs SCORE)
|
| 133 |
+
- 路径类型(LINEAR vs GVP vs VP)
|
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000032.png
ADDED
|
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000077.png
ADDED
|
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000133.png
ADDED
|
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000161.png
ADDED
|
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000220.png
ADDED
|
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000331.png
ADDED
|
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000387.png
ADDED
|
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000505.png
ADDED
|
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000517.png
ADDED
|
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000551.png
ADDED
|
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000726.png
ADDED
|
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000817.png
ADDED
|
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000865.png
ADDED
|
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000914.png
ADDED
|
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000940.png
ADDED
|
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001043.png
ADDED
|
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001210.png
ADDED
|
SiT_back/SiT_clean/W_training.log
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
nohup: ignoring input
|
| 2 |
+
W1124 10:39:29.690000 58030 site-packages/torch/distributed/run.py:793]
|
| 3 |
+
W1124 10:39:29.690000 58030 site-packages/torch/distributed/run.py:793] *****************************************
|
| 4 |
+
W1124 10:39:29.690000 58030 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
|
| 5 |
+
W1124 10:39:29.690000 58030 site-packages/torch/distributed/run.py:793] *****************************************
|
| 6 |
+
[NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
|
| 7 |
+
[NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
|
| 8 |
+
[NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
|
| 9 |
+
[NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
|
| 10 |
+
Starting rank=0, seed=0, world_size=4.
|
| 11 |
+
[[34m2025-11-24 10:39:48[0m] Experiment directory created at results/005-SiT-XL-2-Linear-velocity-None
|
| 12 |
+
[[34m2025-11-24 10:39:48[0m] Sample images will be saved to results/005-SiT-XL-2-Linear-velocity-None/pic
|
| 13 |
+
Starting rank=2, seed=2, world_size=4.
|
| 14 |
+
Starting rank=1, seed=1, world_size=4.
|
| 15 |
+
Starting rank=3, seed=3, world_size=4.
|
| 16 |
+
[[34m2025-11-24 10:40:02[0m] SiT Parameters: 675,129,632
|
| 17 |
+
[[34m2025-11-24 10:40:04[0m] Dataset contains 1,281,167 images (/gemini/platform/public/hzh/datasets/Imagenet/train/)
|
| 18 |
+
[[34m2025-11-24 10:40:04[0m] Training for 140000 epochs...
|
| 19 |
+
[[34m2025-11-24 10:40:04[0m] Beginning epoch 0...
|
| 20 |
+
[[34m2025-11-24 10:40:24[0m] Saved checkpoint to results/005-SiT-XL-2-Linear-velocity-None/checkpoints/0000010.pt
|
| 21 |
+
[[34m2025-11-24 10:40:24[0m] Generating EMA samples...
|
| 22 |
+
[[34m2025-11-24 10:40:25[0m] Saved sample images grid to results/005-SiT-XL-2-Linear-velocity-None/pic/step_0000010_samples_grid.png
|
| 23 |
+
[[34m2025-11-24 10:40:25[0m] Generating EMA samples done.
|
| 24 |
+
W1124 10:40:39.173000 58030 site-packages/torch/distributed/elastic/agent/server/api.py:704] Received 2 death signal, shutting down workers
|
| 25 |
+
W1124 10:40:39.173000 58030 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 58079 closing signal SIGINT
|
| 26 |
+
W1124 10:40:39.174000 58030 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 58080 closing signal SIGINT
|
| 27 |
+
W1124 10:40:39.174000 58030 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 58081 closing signal SIGINT
|
| 28 |
+
W1124 10:40:39.174000 58030 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 58082 closing signal SIGINT
|
| 29 |
+
[rank0]: Traceback (most recent call last):
|
| 30 |
+
[rank0]: File "/gemini/space/gzy_new/Noise_Matching/SiT_clean/train.py", line 371, in <module>
|
| 31 |
+
[rank0]: main(args)
|
| 32 |
+
[rank0]: File "/gemini/space/gzy_new/Noise_Matching/SiT_clean/train.py", line 298, in main
|
| 33 |
+
[rank0]: torch.save(checkpoint, checkpoint_path)
|
| 34 |
+
[rank0]: File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/serialization.py", line 850, in save
|
| 35 |
+
[rank0]: _save(
|
| 36 |
+
[rank0]: File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/serialization.py", line 1114, in _save
|
| 37 |
+
[rank0]: zip_file.write_record(name, storage, num_bytes)
|
| 38 |
+
[rank0]: KeyboardInterrupt
|
| 39 |
+
W1124 10:40:39.380000 58030 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 58079 closing signal SIGTERM
|
| 40 |
+
W1124 10:40:39.380000 58030 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 58080 closing signal SIGTERM
|
| 41 |
+
W1124 10:40:39.381000 58030 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 58081 closing signal SIGTERM
|
| 42 |
+
W1124 10:40:39.381000 58030 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 58082 closing signal SIGTERM
|
| 43 |
+
Traceback (most recent call last):
|
| 44 |
+
File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/api.py", line 696, in run
|
| 45 |
+
result = self._invoke_run(role)
|
| 46 |
+
^^^^^^^^^^^^^^^^^^^^^^
|
| 47 |
+
File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/api.py", line 855, in _invoke_run
|
| 48 |
+
time.sleep(monitor_interval)
|
| 49 |
+
File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 84, in _terminate_process_handler
|
| 50 |
+
raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
|
| 51 |
+
torch.distributed.elastic.multiprocessing.api.SignalException: Process 58030 got signal: 2
|
| 52 |
+
|
| 53 |
+
During handling of the above exception, another exception occurred:
|
| 54 |
+
|
| 55 |
+
Traceback (most recent call last):
|
| 56 |
+
File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/api.py", line 705, in run
|
| 57 |
+
self._shutdown(e.sigval)
|
| 58 |
+
File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py", line 365, in _shutdown
|
| 59 |
+
self._pcontext.close(death_sig)
|
| 60 |
+
File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 572, in close
|
| 61 |
+
self._close(death_sig=death_sig, timeout=timeout)
|
| 62 |
+
File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 909, in _close
|
| 63 |
+
handler.proc.wait(time_to_wait)
|
| 64 |
+
File "/opt/conda/envs/SiT/lib/python3.12/subprocess.py", line 1266, in wait
|
| 65 |
+
return self._wait(timeout=timeout)
|
| 66 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 67 |
+
File "/opt/conda/envs/SiT/lib/python3.12/subprocess.py", line 2055, in _wait
|
| 68 |
+
time.sleep(delay)
|
| 69 |
+
File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 84, in _terminate_process_handler
|
| 70 |
+
raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
|
| 71 |
+
torch.distributed.elastic.multiprocessing.api.SignalException: Process 58030 got signal: 2
|
| 72 |
+
|
| 73 |
+
During handling of the above exception, another exception occurred:
|
| 74 |
+
|
| 75 |
+
Traceback (most recent call last):
|
| 76 |
+
File "/opt/conda/envs/SiT/bin/torchrun", line 33, in <module>
|
| 77 |
+
sys.exit(load_entry_point('torch==2.5.1', 'console_scripts', 'torchrun')())
|
| 78 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 79 |
+
File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
|
| 80 |
+
return f(*args, **kwargs)
|
| 81 |
+
^^^^^^^^^^^^^^^^^^
|
| 82 |
+
File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/run.py", line 919, in main
|
| 83 |
+
run(args)
|
| 84 |
+
File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/run.py", line 910, in run
|
| 85 |
+
elastic_launch(
|
| 86 |
+
File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
|
| 87 |
+
return launch_agent(self._config, self._entrypoint, list(args))
|
| 88 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 89 |
+
File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 260, in launch_agent
|
| 90 |
+
result = agent.run()
|
| 91 |
+
^^^^^^^^^^^
|
| 92 |
+
File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/metrics/api.py", line 137, in wrapper
|
| 93 |
+
result = f(*args, **kwargs)
|
| 94 |
+
^^^^^^^^^^^^^^^^^^
|
| 95 |
+
File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/api.py", line 710, in run
|
| 96 |
+
self._shutdown()
|
| 97 |
+
File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py", line 365, in _shutdown
|
| 98 |
+
self._pcontext.close(death_sig)
|
| 99 |
+
File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 572, in close
|
| 100 |
+
self._close(death_sig=death_sig, timeout=timeout)
|
| 101 |
+
File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 909, in _close
|
| 102 |
+
handler.proc.wait(time_to_wait)
|
| 103 |
+
File "/opt/conda/envs/SiT/lib/python3.12/subprocess.py", line 1266, in wait
|
| 104 |
+
return self._wait(timeout=timeout)
|
| 105 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 106 |
+
File "/opt/conda/envs/SiT/lib/python3.12/subprocess.py", line 2055, in _wait
|
| 107 |
+
time.sleep(delay)
|
| 108 |
+
File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 84, in _terminate_process_handler
|
| 109 |
+
raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
|
| 110 |
+
torch.distributed.elastic.multiprocessing.api.SignalException: Process 58030 got signal: 2
|
SiT_back/SiT_clean/__pycache__/download.cpython-312.pyc
ADDED
|
Binary file (1.99 kB). View file
|
|
|
SiT_back/SiT_clean/__pycache__/models.cpython-312.pyc
ADDED
|
Binary file (20.8 kB). View file
|
|
|
SiT_back/SiT_clean/__pycache__/train_utils.cpython-312.pyc
ADDED
|
Binary file (2.84 kB). View file
|
|
|
SiT_back/SiT_clean/download.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This source code is licensed under the license found in the
|
| 2 |
+
# LICENSE file in the root directory of this source tree.
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Functions for downloading pre-trained SiT models
|
| 6 |
+
"""
|
| 7 |
+
from torchvision.datasets.utils import download_url
|
| 8 |
+
import torch
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
pretrained_models = {'SiT-XL-2-256x256.pt'}
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def find_model(model_name):
    """
    Resolve a SiT checkpoint: auto-download a known pre-trained model,
    or load a local checkpoint file produced by train.py.
    """
    if model_name in pretrained_models:
        # Known released weights -> fetch (and cache) from the web.
        return download_model(model_name)
    # Otherwise treat the argument as a local checkpoint path.
    assert os.path.isfile(model_name), f'Could not find SiT checkpoint at {model_name}'
    checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
    # train.py checkpoints bundle several state dicts; prefer the EMA weights.
    if "ema" in checkpoint:  # supports checkpoints from train.py
        checkpoint = checkpoint["ema"]
    return checkpoint
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def download_model(model_name):
    """Fetch a released SiT checkpoint into ./pretrained_models and load it."""
    assert model_name in pretrained_models
    ckpt_path = f'pretrained_models/{model_name}'
    if not os.path.isfile(ckpt_path):
        # First use: create the cache directory and pull the weights.
        os.makedirs('pretrained_models', exist_ok=True)
        web_path = 'https://www.dl.dropboxusercontent.com/scl/fi/as9oeomcbub47de5g4be0/SiT-XL-2-256.pt?rlkey=uxzxmpicu46coq3msb17b9ofa&dl=0'
        download_url(web_path, 'pretrained_models', filename=model_name)
    # Load onto CPU regardless of where the tensors were saved.
    return torch.load(ckpt_path, map_location=lambda storage, loc: storage)
|
SiT_back/SiT_clean/models.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This source code is licensed under the license found in the
|
| 2 |
+
# LICENSE file in the root directory of this source tree.
|
| 3 |
+
# --------------------------------------------------------
|
| 4 |
+
# References:
|
| 5 |
+
# GLIDE: https://github.com/openai/glide-text2im
|
| 6 |
+
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
|
| 7 |
+
# --------------------------------------------------------
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import numpy as np
|
| 12 |
+
import math
|
| 13 |
+
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def modulate(x, shift, scale):
    """Apply adaLN shift/scale conditioning to a (N, T, D) tensor.

    shift and scale are per-sample (N, D) vectors broadcast over tokens.
    """
    scale = scale.unsqueeze(1)
    shift = shift.unsqueeze(1)
    return x * (1 + scale) + shift
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
#################################################################################
|
| 21 |
+
# Embedding Layers for Timesteps and Class Labels #
|
| 22 |
+
#################################################################################
|
| 23 |
+
|
| 24 |
+
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    Pipeline: sinusoidal features -> 2-layer SiLU MLP -> (N, hidden_size).
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices (possibly fractional), one per batch element.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, dim) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        angles = t[:, None].float() * freqs[None]
        features = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if dim % 2:
            # Odd dim: pad one zero column so the output has exactly `dim` features.
            features = torch.cat([features, torch.zeros_like(features[:, :1])], dim=-1)
        return features

    def forward(self, t):
        sinusoids = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(sinusoids)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout
    for classifier-free guidance: dropped labels map to an extra "null" index.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        # Reserve one extra row (index == num_classes) for the CFG null label
        # whenever dropout is enabled.
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            # Random per-sample dropout with probability dropout_prob.
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            # Deterministic dropout mask supplied by the caller (1 == drop).
            drop_ids = force_drop_ids == 1
        # Dropped positions are replaced by the null-label index.
        return torch.where(drop_ids, self.num_classes, labels)

    def forward(self, labels, train, force_drop_ids=None):
        should_drop = (train and self.dropout_prob > 0) or (force_drop_ids is not None)
        if should_drop:
            labels = self.token_drop(labels, force_drop_ids)
        return self.embedding_table(labels)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
#################################################################################
|
| 95 |
+
# Core SiT Model #
|
| 96 |
+
#################################################################################
|
| 97 |
+
|
| 98 |
+
class SiTBlock(nn.Module):
    """
    A SiT block with adaptive layer norm zero (adaLN-Zero) conditioning.

    The conditioning vector c produces six per-sample modulation vectors:
    shift/scale/gate for the attention branch and for the MLP branch.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=approx_gelu,
            drop=0,
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        mods = self.adaLN_modulation(c).chunk(6, dim=1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mods
        # Gated residual attention branch.
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        # Gated residual MLP branch.
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class FinalLayer(nn.Module):
    """
    The final layer of SiT: adaLN modulation followed by a linear projection
    to patch_size**2 * out_channels values per token.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class SiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    Patchifies a latent image, conditions every block on timestep + class
    embeddings via adaLN-Zero, and unpatchifies the final projection back
    to image shape.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        # NOTE(review): the line below unconditionally overrides the constructor
        # argument, so learn_sigma is effectively always True and the output head
        # is always sized for 2 * in_channels. Callers passing learn_sigma=False
        # (e.g. sample.py's custom-checkpoint path) get it silently ignored —
        # confirm this is intentional before removing, since trained checkpoints
        # depend on the 2*C final-layer shape.
        self.learn_sigma = True
        self.in_channels = in_channels
        # Head predicts both the main output and a second set of channels
        # (sigma), which forward() discards when learn_sigma is set.
        self.out_channels = in_channels * 2
        self.patch_size = patch_size
        self.num_heads = num_heads

        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding (frozen, never trained):
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        self.blocks = nn.ModuleList([
            SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """Apply the SiT/DiT initialization scheme (xavier linears + adaLN-zero)."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize (and freeze) pos_embed by sin-cos embedding:
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize label embedding table:
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in SiT blocks
        # (each residual branch starts as an identity mapping):
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers (the model initially predicts zeros):
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        # Square token grid assumed: T must be a perfect square.
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        # h == w here (asserted above), so using h for both spatial dims is safe.
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

    def forward(self, x, t, y):
        """
        Forward pass of SiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """
        x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(t)                   # (N, D)
        y = self.y_embedder(y, self.training)    # (N, D)
        c = t + y                                # (N, D)
        for block in self.blocks:
            x = block(x, c)                      # (N, T, D)
        x = self.final_layer(x, c)                # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)                   # (N, out_channels, H, W)
        if self.learn_sigma:
            # Drop the sigma half; only the first in_channels are returned.
            x, _ = x.chunk(2, dim=1)
        return x

    def forward_with_cfg(self, x, t, y, cfg_scale):
        """
        Forward pass of SiT, but also batches the unconditional forward pass
        for classifier-free guidance.

        Expects the batch to be [conditional half; duplicated half] with the
        matching null labels in the second half of y.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, y)
        # For exact reproducibility reasons, we apply classifier-free guidance on only
        # three channels by default. The standard approach to cfg applies it to all channels.
        # This can be done by uncommenting the following line and commenting-out the line following that.
        # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
#################################################################################
|
| 270 |
+
# Sine/Cosine Positional Embedding Functions #
|
| 271 |
+
#################################################################################
|
| 272 |
+
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
| 273 |
+
|
| 274 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [extra_tokens+grid_size*grid_size, embed_dim]
    (with or without leading cls/extra token rows)
    """
    coords = np.arange(grid_size, dtype=np.float32)
    # meshgrid with w first, matching the reference MAE implementation.
    grid = np.stack(np.meshgrid(coords, coords), axis=0)
    grid = grid.reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        # Prepend zero vectors for cls/register tokens.
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Combine two 1-D sin-cos embeddings (rows, cols) into a 2-D embedding."""
    assert embed_dim % 2 == 0

    # Half the channels encode the first grid axis, half the second.
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    # Geometric frequency ladder, same schedule as the transformer sinusoid.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    angles = np.einsum('m,d->md', pos.reshape(-1), omega)  # (M, D/2), outer product

    # Sine features first, then cosine features.
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
#################################################################################
|
| 325 |
+
# SiT Configs #
|
| 326 |
+
#################################################################################
|
| 327 |
+
|
| 328 |
+
# Named SiT configurations.  Naming follows "SiT-{size}/{patch}": size fixes
# depth/hidden_size/num_heads, patch is the ViT patch size.  Remaining
# keyword arguments (input_size, num_classes, learn_sigma, ...) pass through
# to the SiT constructor.

def SiT_XL_2(**kwargs):
    return SiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)

def SiT_XL_4(**kwargs):
    return SiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)

def SiT_XL_8(**kwargs):
    return SiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)

def SiT_L_2(**kwargs):
    return SiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)

def SiT_L_4(**kwargs):
    return SiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)

def SiT_L_8(**kwargs):
    return SiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)

def SiT_B_2(**kwargs):
    return SiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)

def SiT_B_4(**kwargs):
    return SiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)

def SiT_B_8(**kwargs):
    return SiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)

def SiT_S_2(**kwargs):
    return SiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)

def SiT_S_4(**kwargs):
    return SiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)

def SiT_S_8(**kwargs):
    return SiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)


# Registry used by the CLI (--model flag) to look up a config by name.
SiT_models = {
    'SiT-XL/2': SiT_XL_2,  'SiT-XL/4': SiT_XL_4,  'SiT-XL/8': SiT_XL_8,
    'SiT-L/2':  SiT_L_2,   'SiT-L/4':  SiT_L_4,   'SiT-L/8':  SiT_L_8,
    'SiT-B/2':  SiT_B_2,   'SiT-B/4':  SiT_B_4,   'SiT-B/8':  SiT_B_8,
    'SiT-S/2':  SiT_S_2,   'SiT-S/4':  SiT_S_4,   'SiT-S/8':  SiT_S_8,
}
|
SiT_back/SiT_clean/run.sh
ADDED
|
File without changes
|
SiT_back/SiT_clean/sample.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This source code is licensed under the license found in the
|
| 2 |
+
# LICENSE file in the root directory of this source tree.
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Sample new images from a pre-trained SiT.
|
| 6 |
+
"""
|
| 7 |
+
import torch
|
| 8 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 9 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 10 |
+
from torchvision.utils import save_image
|
| 11 |
+
from diffusers.models import AutoencoderKL
|
| 12 |
+
from download import find_model
|
| 13 |
+
from models import SiT_models
|
| 14 |
+
from train_utils import parse_ode_args, parse_sde_args, parse_transport_args
|
| 15 |
+
from transport import create_transport, Sampler
|
| 16 |
+
import argparse
|
| 17 |
+
import sys
|
| 18 |
+
from time import time
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def main(mode, args):
    """Sample a fixed batch of class-conditional images with a pre-trained SiT.

    mode is "ODE" or "SDE" and selects the transport sampler; args carries the
    model/sampler hyperparameters parsed in __main__.  Writes "sample.png".
    """
    # Setup PyTorch:
    torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if args.ckpt is None:
        # Auto-download path only supports the released SiT-XL/2 @ 256 model.
        assert args.model == "SiT-XL/2", "Only SiT-XL/2 models are available for auto-download."
        assert args.image_size in [256, 512]
        assert args.num_classes == 1000
        assert args.image_size == 256, "512x512 models are not yet available for auto-download." # remove this line when 512x512 models are available
        learn_sigma = args.image_size == 256
    else:
        # NOTE(review): custom checkpoints are assumed to have been trained
        # without learned sigma — confirm against the training configuration.
        learn_sigma = False

    # Load model:
    latent_size = args.image_size // 8  # SD VAE downsamples by 8x
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        learn_sigma=learn_sigma,
    ).to(device)
    # Auto-download a pre-trained model or load a custom SiT checkpoint from train.py:
    ckpt_path = args.ckpt or f"SiT-XL-2-{args.image_size}x{args.image_size}.pt"
    state_dict = find_model(ckpt_path)
    model.load_state_dict(state_dict)
    model.eval()  # important!
    transport = create_transport(
        args.path_type,
        args.prediction,
        args.loss_weight,
        args.train_eps,
        args.sample_eps
    )
    sampler = Sampler(transport)
    # Build the integrator for the chosen mode.
    if mode == "ODE":
        if args.likelihood:
            # Likelihood evaluation requires the unguided probability-flow ODE.
            assert args.cfg_scale == 1, "Likelihood is incompatible with guidance"
            sample_fn = sampler.sample_ode_likelihood(
                sampling_method=args.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=args.atol,
                rtol=args.rtol,
            )
        else:
            sample_fn = sampler.sample_ode(
                sampling_method=args.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=args.atol,
                rtol=args.rtol,
                reverse=args.reverse
            )

    elif mode == "SDE":
        sample_fn = sampler.sample_sde(
            sampling_method=args.sampling_method,
            diffusion_form=args.diffusion_form,
            diffusion_norm=args.diffusion_norm,
            last_step=args.last_step,
            last_step_size=args.last_step_size,
            num_steps=args.num_sampling_steps,
        )


    # VAE decoder to map latents back to pixel space.
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)

    # Labels to condition the model with (feel free to change):
    class_labels = [207, 360, 387, 974, 88, 979, 417, 279]

    # Create sampling noise:
    n = len(class_labels)
    z = torch.randn(n, 4, latent_size, latent_size, device=device)
    y = torch.tensor(class_labels, device=device)

    # Setup classifier-free guidance: duplicate the noise and append
    # null labels (index 1000) so cond/uncond run in one batch.
    z = torch.cat([z, z], 0)
    y_null = torch.tensor([1000] * n, device=device)
    y = torch.cat([y, y_null], 0)
    model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)

    # Sample images:
    start_time = time()
    samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
    samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
    # 0.18215 is the SD VAE latent scaling factor.
    samples = vae.decode(samples / 0.18215).sample
    print(f"Sampling took {time() - start_time:.2f} seconds.")

    # Save and display images:
    save_image(samples, "sample.png", nrow=4, normalize=True, value_range=(-1, 1))
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # The first positional token selects the sampler family and must come
    # before any -- options, because it decides which flags get registered.
    if len(sys.argv) < 2:
        print("Usage: program.py <mode> [options]")
        sys.exit(1)

    mode = sys.argv[1]

    assert mode[:2] != "--", "Usage: program.py <mode> [options]"
    assert mode in ["ODE", "SDE"], "Invalid mode. Please choose 'ODE' or 'SDE'"

    parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="mse")
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--cfg-scale", type=float, default=4.0)
    parser.add_argument("--num-sampling-steps", type=int, default=250)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--ckpt", type=str, default=None,
                        help="Optional path to a SiT checkpoint (default: auto-download a pre-trained SiT-XL/2 model).")


    # Mode-specific flags are registered dynamically.
    parse_transport_args(parser)
    if mode == "ODE":
        parse_ode_args(parser)
        # Further processing for ODE
    elif mode == "SDE":
        parse_sde_args(parser)
        # Further processing for SDE

    # parse_known_args ignores the positional mode token already consumed above.
    args = parser.parse_known_args()[0]
    main(mode, args)
|
SiT_back/SiT_clean/sample_ddp.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This source code is licensed under the license found in the
|
| 2 |
+
# LICENSE file in the root directory of this source tree.
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Samples a large number of images from a pre-trained SiT model using DDP.
|
| 6 |
+
Subsequently saves a .npz file that can be used to compute FID and other
|
| 7 |
+
evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations
|
| 8 |
+
|
| 9 |
+
For a simple single-GPU/CPU sampling script, see sample.py.
|
| 10 |
+
"""
|
| 11 |
+
import torch
|
| 12 |
+
import torch.distributed as dist
|
| 13 |
+
from models import SiT_models
|
| 14 |
+
from download import find_model
|
| 15 |
+
from transport import create_transport, Sampler
|
| 16 |
+
from diffusers.models import AutoencoderKL
|
| 17 |
+
from train_utils import parse_ode_args, parse_sde_args, parse_transport_args
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
import os
|
| 20 |
+
from PIL import Image
|
| 21 |
+
import numpy as np
|
| 22 |
+
import math
|
| 23 |
+
import argparse
|
| 24 |
+
import sys
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Builds a single .npz file from a folder of .png samples.

    Expects files named 000000.png ... {num-1:06d}.png inside *sample_dir*
    and writes the stacked uint8 array to "<sample_dir>.npz" under the key
    "arr_0" (the layout read by the ADM FID evaluation suite).
    """
    frames = [
        np.asarray(Image.open(f"{sample_dir}/{idx:06d}.png")).astype(np.uint8)
        for idx in tqdm(range(num), desc="Building .npz file from samples")
    ]
    stacked = np.stack(frames)
    # Sanity check: N images, each H x W with 3 RGB channels.
    assert stacked.shape == (num, stacked.shape[1], stacked.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=stacked)
    print(f"Saved .npz file to {npz_path} [shape={stacked.shape}].")
    return npz_path
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def main(mode, args):
    """
    Run DDP sampling for FID: every rank generates its share of images,
    writes them as individual PNGs, and rank 0 packs them into a .npz.

    Fixes vs. previous revision:
      * the PNG-saving loop reused the outer iteration variable `i`,
        shadowing the progress counter — renamed to `j`;
      * corrected the cfg_scale assertion message.
    """
    torch.backends.cuda.matmul.allow_tf32 = args.tf32  # True: fast but may lead to some small numerical differences
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Setup DDP: one process per GPU, per-rank deterministic seed.
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    if args.ckpt is None:
        # Auto-download path only supports the published SiT-XL/2 256x256 model.
        assert args.model == "SiT-XL/2", "Only SiT-XL/2 models are available for auto-download."
        assert args.image_size in [256, 512]
        assert args.num_classes == 1000
        assert args.image_size == 256, "512x512 models are not yet available for auto-download."  # remove this line when 512x512 models are available
        learn_sigma = args.image_size == 256
    else:
        learn_sigma = False

    # Load model (latents are 8x downsampled by the SD VAE):
    latent_size = args.image_size // 8
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        learn_sigma=learn_sigma,
    ).to(device)
    # Auto-download a pre-trained model or load a custom SiT checkpoint from train.py:
    ckpt_path = args.ckpt or f"SiT-XL-2-{args.image_size}x{args.image_size}.pt"
    state_dict = find_model(ckpt_path)
    model.load_state_dict(state_dict)
    model.eval()  # important!

    transport = create_transport(
        args.path_type,
        args.prediction,
        args.loss_weight,
        args.train_eps,
        args.sample_eps
    )
    sampler = Sampler(transport)
    # Build the sampling function for the requested solver family.
    if mode == "ODE":
        if args.likelihood:
            assert args.cfg_scale == 1, "Likelihood is incompatible with guidance"
            sample_fn = sampler.sample_ode_likelihood(
                sampling_method=args.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=args.atol,
                rtol=args.rtol,
            )
        else:
            sample_fn = sampler.sample_ode(
                sampling_method=args.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=args.atol,
                rtol=args.rtol,
                reverse=args.reverse
            )
    elif mode == "SDE":
        sample_fn = sampler.sample_sde(
            sampling_method=args.sampling_method,
            diffusion_form=args.diffusion_form,
            diffusion_norm=args.diffusion_norm,
            last_step=args.last_step,
            last_step_size=args.last_step_size,
            num_steps=args.num_sampling_steps,
        )
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
    assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale should be >= 1.0"
    using_cfg = args.cfg_scale > 1.0

    # Create folder to save samples (name encodes the sampling configuration):
    model_string_name = args.model.replace("/", "-")
    ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
    if mode == "ODE":
        folder_name = f"{model_string_name}-{ckpt_string_name}-" \
                      f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
                      f"{mode}-{args.num_sampling_steps}-{args.sampling_method}"
    elif mode == "SDE":
        folder_name = f"{model_string_name}-{ckpt_string_name}-" \
                      f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
                      f"{mode}-{args.num_sampling_steps}-{args.sampling_method}-"\
                      f"{args.diffusion_form}-{args.last_step}-{args.last_step_size}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()
    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
    num_samples = len([name for name in os.listdir(sample_folder_dir) if (os.path.isfile(os.path.join(sample_folder_dir, name)) and ".png" in name)])
    total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
    assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // dist.get_world_size())
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    # NOTE(review): done_iterations is computed from the PNGs already present
    # but is never used to skip completed work — confirm whether resumable
    # sampling was intended here.
    done_iterations = int( int(num_samples // dist.get_world_size()) // n)
    pbar = range(iterations)
    pbar = tqdm(pbar) if rank == 0 else pbar
    total = 0

    for i in pbar:
        # Sample inputs:
        z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device)
        y = torch.randint(0, args.num_classes, (n,), device=device)

        # Setup classifier-free guidance: duplicate the batch with the null
        # class (id 1000) so conditional/unconditional passes share one call.
        if using_cfg:
            z = torch.cat([z, z], 0)
            y_null = torch.tensor([1000] * n, device=device)
            y = torch.cat([y, y_null], 0)
            model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
            model_fn = model.forward_with_cfg
        else:
            model_kwargs = dict(y=y)
            model_fn = model.forward

        samples = sample_fn(z, model_fn, **model_kwargs)[-1]
        if using_cfg:
            samples, _ = samples.chunk(2, dim=0)  # Remove null class samples

        samples = vae.decode(samples / 0.18215).sample
        # Map latent decodes from roughly [-1, 1] to uint8 HWC images.
        samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

        # Save samples to disk as individual .png files.
        # Use `j` (not `i`) so the outer iteration counter is not shadowed.
        for j, sample in enumerate(samples):
            index = j * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
        total += global_batch_size
        dist.barrier()

    # Make sure all processes have finished saving their samples before attempting to convert to .npz
    dist.barrier()
    if rank == 0:
        create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
        print("Done.")
    dist.barrier()
    dist.destroy_process_group()
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
if __name__ == "__main__":
    # The first positional CLI token selects the sampler family ("ODE" or
    # "SDE"); it is consumed by hand so the mode-specific flags can be
    # registered before the remaining argv is parsed.
    parser = argparse.ArgumentParser()

    if len(sys.argv) < 2:
        print("Usage: program.py <mode> [options]")
        sys.exit(1)

    mode = sys.argv[1]

    # Explicit checks instead of `assert`: asserts are stripped under
    # `python -O`, which would silently accept a malformed mode string.
    if mode[:2] == "--":
        print("Usage: program.py <mode> [options]")
        sys.exit(1)
    if mode not in ("ODE", "SDE"):
        print("Invalid mode. Please choose 'ODE' or 'SDE'")
        sys.exit(1)

    parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")
    parser.add_argument("--sample-dir", type=str, default="samples")
    parser.add_argument("--per-proc-batch-size", type=int, default=4)
    parser.add_argument("--num-fid-samples", type=int, default=50_000)
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--cfg-scale", type=float, default=1.0)
    parser.add_argument("--num-sampling-steps", type=int, default=250)
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
                        help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.")
    parser.add_argument("--ckpt", type=str, default=None,
                        help="Optional path to a SiT checkpoint (default: auto-download a pre-trained SiT-XL/2 model).")

    # Register solver flags only for the selected mode.
    parse_transport_args(parser)
    if mode == "ODE":
        parse_ode_args(parser)
        # Further processing for ODE
    elif mode == "SDE":
        parse_sde_args(parser)
        # Further processing for SDE

    # parse_known_args tolerates the positional <mode> token consumed above.
    args = parser.parse_known_args()[0]
    main(mode, args)
|
SiT_back/SiT_clean/train.py
ADDED
|
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This source code is licensed under the license found in the
|
| 2 |
+
# LICENSE file in the root directory of this source tree.
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
A minimal training script for SiT using PyTorch DDP.
|
| 6 |
+
"""
|
| 7 |
+
import torch
|
| 8 |
+
# the first flag below was False when we tested this script but True makes A100 training a lot faster:
|
| 9 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 10 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 11 |
+
import torch.distributed as dist
|
| 12 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 13 |
+
from torch.utils.data import DataLoader
|
| 14 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 15 |
+
from torchvision.datasets import ImageFolder
|
| 16 |
+
from torchvision import transforms
|
| 17 |
+
import numpy as np
|
| 18 |
+
from collections import OrderedDict
|
| 19 |
+
from PIL import Image
|
| 20 |
+
from copy import deepcopy
|
| 21 |
+
from glob import glob
|
| 22 |
+
from time import time
|
| 23 |
+
import argparse
|
| 24 |
+
import logging
|
| 25 |
+
import os
|
| 26 |
+
|
| 27 |
+
from models import SiT_models
|
| 28 |
+
from download import find_model
|
| 29 |
+
from transport import create_transport, Sampler
|
| 30 |
+
from diffusers.models import AutoencoderKL
|
| 31 |
+
from train_utils import parse_transport_args
|
| 32 |
+
|
| 33 |
+
#################################################################################
|
| 34 |
+
# Training Helper Functions #
|
| 35 |
+
#################################################################################
|
| 36 |
+
|
| 37 |
+
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.

    Each EMA parameter p_ema becomes decay * p_ema + (1 - decay) * p_model,
    applied in place under torch.no_grad().
    """
    targets = dict(ema_model.named_parameters())
    # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
    for name, source in model.named_parameters():
        targets[name].mul_(decay).add_(source.data, alpha=1 - decay)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def requires_grad(model, flag=True):
    """
    Set requires_grad flag for all parameters in a model.
    """
    for parameter in model.parameters():
        parameter.requires_grad_(flag)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def cleanup():
    """
    End DDP training.

    Tears down the default process group created by dist.init_process_group
    in main(); should run on every rank before the process exits.
    """
    dist.destroy_process_group()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def create_logger(logging_dir):
    """
    Create a logger that writes to a log file and stdout.

    Only rank 0 gets a real logger (console + {logging_dir}/log.txt); every
    other rank receives a logger with a NullHandler so its records are
    dropped. Must be called after dist.init_process_group (uses
    dist.get_rank()); non-zero ranks may pass logging_dir=None.
    """
    if dist.get_rank() == 0:  # real logger
        # NOTE: basicConfig mutates the root logger for the whole process.
        logging.basicConfig(
            level=logging.INFO,
            # \033[34m ... \033[0m renders the timestamp in ANSI blue.
            format='[\033[34m%(asctime)s\033[0m] %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
            handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
        )
        logger = logging.getLogger(__name__)
    else:  # dummy logger (does nothing)
        logger = logging.getLogger(__name__)
        logger.addHandler(logging.NullHandler())
    return logger
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Repeatedly halves the image with a box filter while it is at least twice
    the target size, bicubic-resizes so the short side equals image_size,
    then crops the central image_size x image_size region.
    """
    # Coarse stage: halve with BOX resampling while still comfortably large.
    while min(*pil_image.size) >= 2 * image_size:
        halved = tuple(dim // 2 for dim in pil_image.size)
        pil_image = pil_image.resize(halved, resample=Image.BOX)

    # Fine stage: land the shorter side exactly on image_size.
    ratio = image_size / min(*pil_image.size)
    target = tuple(round(dim * ratio) for dim in pil_image.size)
    pil_image = pil_image.resize(target, resample=Image.BICUBIC)

    # Crop the centered square out of the (H, W, C) pixel array.
    arr = np.array(pil_image)
    top = (arr.shape[0] - image_size) // 2
    left = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[top: top + image_size, left: left + image_size])
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
#################################################################################
|
| 105 |
+
# Training Loop #
|
| 106 |
+
#################################################################################
|
| 107 |
+
|
| 108 |
+
def main(args):
    """
    Trains a new SiT model with PyTorch DDP (one process per GPU).

    Rank 0 owns the experiment directory, logging, checkpointing, and the
    periodic EMA sample-grid saving; all ranks share the training loop.

    Fixes vs. previous revision:
      * the --ckpt resume block ran before `opt` existed, so
        `opt.load_state_dict(...)` raised NameError; it also replaced `args`
        wholesale with the checkpointed args after the experiment directories
        had already been built from the current args. Resume now happens
        after model/EMA/optimizer construction and keeps the CLI args.
      * `logging.info` (root logger) replaced by `logger.info` after sampling.
    """
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."

    # Setup DDP:
    dist.init_process_group("nccl")
    assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size must be divisible by world size."
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
    local_batch_size = int(args.global_batch_size // dist.get_world_size())

    # Setup an experiment folder:
    if rank == 0:
        os.makedirs(args.results_dir, exist_ok=True)  # Make results folder (holds all experiment subfolders)
        experiment_index = len(glob(f"{args.results_dir}/*"))
        model_string_name = args.model.replace("/", "-")  # e.g., SiT-XL/2 --> SiT-XL-2 (for naming folders)
        experiment_name = f"{experiment_index:03d}-{model_string_name}-" \
                          f"{args.path_type}-{args.prediction}-{args.loss_weight}"
        experiment_dir = f"{args.results_dir}/{experiment_name}"  # Create an experiment folder
        checkpoint_dir = f"{experiment_dir}/checkpoints"  # Stores saved model checkpoints
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Create pic directory for saving sample images
        pic_dir = f"{experiment_dir}/pic"
        os.makedirs(pic_dir, exist_ok=True)

        logger = create_logger(experiment_dir)
        logger.info(f"Experiment directory created at {experiment_dir}")
        logger.info(f"Sample images will be saved to {pic_dir}")
    else:
        logger = create_logger(None)

    # Create model:
    assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
    latent_size = args.image_size // 8
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes
    )

    # Note that parameter initialization is done within the SiT constructor
    ema = deepcopy(model).to(device)  # Create an EMA of the model for use after training
    requires_grad(ema, False)

    model = DDP(model.to(device), device_ids=[device])
    transport = create_transport(
        args.path_type,
        args.prediction,
        args.loss_weight,
        args.train_eps,
        args.sample_eps
    )  # default: velocity;
    transport_sampler = Sampler(transport)
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
    logger.info(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0)

    # Resume from a checkpoint written by the training loop below, if given.
    # Loaded AFTER model/EMA/optimizer construction so every state target
    # exists, and without overwriting the freshly parsed CLI args.
    resuming = args.ckpt is not None
    if resuming:
        state_dict = find_model(args.ckpt)
        model.module.load_state_dict(state_dict["model"])
        ema.load_state_dict(state_dict["ema"])
        opt.load_state_dict(state_dict["opt"])
        logger.info(f"Resumed model/EMA/optimizer state from {args.ckpt}")

    # Setup data:
    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])
    dataset = ImageFolder(args.data_path, transform=transform)
    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=rank,
        shuffle=True,
        seed=args.global_seed
    )
    loader = DataLoader(
        dataset,
        batch_size=local_batch_size,
        shuffle=False,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})")

    # Prepare models for training:
    if not resuming:
        update_ema(ema, model.module, decay=0)  # Ensure EMA is initialized with synced weights
    model.train()  # important! This enables embedding dropout for classifier-free guidance
    ema.eval()  # EMA model should always be in eval mode

    # Variables for monitoring/logging purposes:
    train_steps = 0
    log_steps = 0
    running_loss = 0
    start_time = time()

    # Labels to condition the model with (feel free to change):
    ys = torch.randint(1000, size=(local_batch_size,), device=device)
    use_cfg = args.cfg_scale > 1.0
    # Create sampling noise:
    n = ys.size(0)
    zs = torch.randn(n, 4, latent_size, latent_size, device=device)

    # Create fixed sampling noise and conditions for consistent sampling visualization
    fixed_ys = torch.randint(1000, size=(16,), device=device)  # Fixed labels for sampling
    fixed_zs = torch.randn(16, 4, latent_size, latent_size, device=device)  # Fixed noise for sampling

    # Setup classifier-free guidance (null class id is 1000):
    if use_cfg:
        zs = torch.cat([zs, zs], 0)
        y_null = torch.tensor([1000] * n, device=device)
        ys = torch.cat([ys, y_null], 0)
        sample_model_kwargs = dict(y=ys, cfg_scale=args.cfg_scale)
        model_fn = ema.forward_with_cfg
    else:
        sample_model_kwargs = dict(y=ys)
        model_fn = ema.forward

    # Setup fixed classifier-free guidance for sampling:
    if args.cfg_scale > 1.0:
        fixed_zs = torch.cat([fixed_zs, fixed_zs], 0)
        fixed_y_null = torch.tensor([1000] * 16, device=device)
        fixed_ys = torch.cat([fixed_ys, fixed_y_null], 0)
        fixed_sample_model_kwargs = dict(y=fixed_ys, cfg_scale=args.cfg_scale)
    else:
        fixed_sample_model_kwargs = dict(y=fixed_ys)

    logger.info(f"Training for {args.epochs} epochs...")
    for epoch in range(args.epochs):
        sampler.set_epoch(epoch)
        logger.info(f"Beginning epoch {epoch}...")
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            with torch.no_grad():
                # Map input images to latent space + normalize latents:
                x = vae.encode(x).latent_dist.sample().mul_(0.18215)
            model_kwargs = dict(y=y)
            loss_dict = transport.training_losses(model, x, model_kwargs)
            loss = loss_dict["loss"].mean()
            opt.zero_grad()
            loss.backward()
            opt.step()
            update_ema(ema, model.module)

            # Log loss values:
            running_loss += loss.item()
            log_steps += 1
            train_steps += 1
            if train_steps % args.log_every == 0:
                # Measure training speed:
                torch.cuda.synchronize()
                end_time = time()
                steps_per_sec = log_steps / (end_time - start_time)
                # Reduce loss history over all processes:
                avg_loss = torch.tensor(running_loss / log_steps, device=device)
                dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                avg_loss = avg_loss.item() / dist.get_world_size()
                logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
                # Reset monitoring variables:
                running_loss = 0
                log_steps = 0
                start_time = time()

            # Save SiT checkpoint:
            if train_steps % args.ckpt_every == 0 and train_steps > 0:
                if rank == 0:
                    checkpoint = {
                        "model": model.module.state_dict(),
                        "ema": ema.state_dict(),
                        "opt": opt.state_dict(),
                        "args": args
                    }
                    checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
                    torch.save(checkpoint, checkpoint_path)
                    logger.info(f"Saved checkpoint to {checkpoint_path}")
                dist.barrier()

            # Save sample images:
            if train_steps % args.sample_every == 0 and train_steps > 0:
                logger.info("Generating EMA samples...")
                sample_fn = transport_sampler.sample_ode()  # default to ode sampling
                samples = sample_fn(fixed_zs, model_fn, **fixed_sample_model_kwargs)[-1]
                dist.barrier()

                if args.cfg_scale > 1.0:  # remove null samples
                    samples, _ = samples.chunk(2, dim=0)
                samples = vae.decode(samples / 0.18215).sample

                # Save sample images to pic directory instead of wandb
                if rank == 0:
                    # Normalize images from [-1, 1] to [0, 1]
                    samples = (samples.clamp(-1, 1) + 1) / 2
                    # Arrange the first 16 samples in a 4x4 grid image.
                    grid_size = args.image_size
                    grid_image = Image.new('RGB', (4 * grid_size, 4 * grid_size))

                    # Place each sample in the grid
                    for i in range(min(16, samples.shape[0])):
                        # Convert CHW float tensor to HWC uint8 PIL image
                        img = samples[i].permute(1, 2, 0).cpu().detach().numpy()
                        img = (img * 255).astype(np.uint8)
                        pil_img = Image.fromarray(img)

                        # Calculate position in the grid
                        row = i // 4
                        col = i % 4
                        grid_image.paste(pil_img, (col * grid_size, row * grid_size))

                    # Save the grid image
                    img_path = f"{pic_dir}/step_{train_steps:07d}_samples_grid.png"
                    grid_image.save(img_path)
                    logger.info(f"Saved sample images grid to {img_path}")

                logger.info("Generating EMA samples done.")

    model.eval()  # important! This disables randomized embedding dropout
    # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...

    logger.info("Done!")
    cleanup()
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
if __name__ == "__main__":
    # Default args here will train SiT-XL/2 with the hyperparameters we used in our paper (except training iters).
    parser = argparse.ArgumentParser()
    # NOTE(review): the default data path is a machine-specific cluster path —
    # confirm before sharing/releasing.
    parser.add_argument("--data-path", type=str, default="/gemini/platform/public/hzh/datasets/Imagenet/train/")
    parser.add_argument("--results-dir", type=str, default="results")
    parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--num-classes", type=int, default=1000)
    # NOTE(review): named "epochs" but 140000 looks like a step count — confirm
    # the intended unit with the training loop.
    parser.add_argument("--epochs", type=int, default=140000)
    parser.add_argument("--global-batch-size", type=int, default=256)
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")  # Choice doesn't affect training
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--log-every", type=int, default=100)
    # NOTE(review): checkpointing/sampling every 10 steps is very frequent —
    # these look like debugging values; confirm before long runs.
    parser.add_argument("--ckpt-every", type=int, default=10)
    parser.add_argument("--sample-every", type=int, default=10)
    parser.add_argument("--cfg-scale", type=float, default=4.0)
    parser.add_argument("--ckpt", type=str, default=None,
                        help="Optional path to a custom SiT checkpoint")

    # Shared transport flags (path type, prediction target, loss weighting).
    parse_transport_args(parser)
    args = parser.parse_args()
    main(args)
|
SiT_back/SiT_clean/train_utils.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def none_or_str(value):
    """argparse ``type=`` helper: map the literal string 'None' to Python None.

    Every other value (including the empty string) passes through unchanged.
    """
    return None if value == 'None' else value
|
| 5 |
+
|
| 6 |
+
def parse_transport_args(parser):
    """Register the interpolant/transport CLI flags on *parser*.

    Flags are grouped under "Transport arguments"; defaults select the
    Linear path with velocity prediction and no loss weighting.
    """
    group = parser.add_argument_group("Transport arguments")
    add = group.add_argument
    add("--path-type", type=str, default="Linear", choices=["Linear", "GVP", "VP"])
    add("--prediction", type=str, default="velocity", choices=["velocity", "score", "noise"])
    add("--loss-weight", type=none_or_str, default=None, choices=[None, "velocity", "likelihood"])
    add("--sample-eps", type=float)
    add("--train-eps", type=float)
|
| 13 |
+
|
| 14 |
+
def parse_ode_args(parser):
    """Register the ODE-solver CLI flags on *parser* under an "ODE arguments" group."""
    ode_group = parser.add_argument_group("ODE arguments")
    ode_group.add_argument(
        "--sampling-method",
        type=str,
        default="dopri5",
        help="blackbox ODE solver methods; for full list check https://github.com/rtqichen/torchdiffeq",
    )
    ode_group.add_argument("--atol", type=float, default=1e-6, help="Absolute tolerance")
    ode_group.add_argument("--rtol", type=float, default=1e-3, help="Relative tolerance")
    ode_group.add_argument("--reverse", action="store_true")
    ode_group.add_argument("--likelihood", action="store_true")
|
| 21 |
+
|
| 22 |
+
def parse_sde_args(parser):
    """Register the SDE-sampler CLI flags on *parser* under an "SDE arguments" group."""
    sde_group = parser.add_argument_group("SDE arguments")
    sde_group.add_argument("--sampling-method", type=str, default="Euler", choices=["Euler", "Heun"])
    sde_group.add_argument(
        "--diffusion-form",
        type=str,
        default="sigma",
        choices=["constant", "SBDM", "sigma", "linear", "decreasing", "increasing-decreasing"],
        help="form of diffusion coefficient in the SDE",
    )
    sde_group.add_argument("--diffusion-norm", type=float, default=1.0)
    sde_group.add_argument(
        "--last-step",
        type=none_or_str,
        default="Mean",
        choices=[None, "Mean", "Tweedie", "Euler"],
        help="form of last step taken in the SDE",
    )
    sde_group.add_argument(
        "--last-step-size",
        type=float,
        default=0.04,
        help="size of the last step taken",
    )
|
SiT_back/SiT_clean/transport/__init__.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .transport import Transport, ModelType, WeightType, PathType, Sampler
|
| 2 |
+
|
| 3 |
+
def create_transport(
    path_type='Linear',
    prediction="velocity",
    loss_weight=None,
    train_eps=None,
    sample_eps=None,
):
    """function for creating Transport object
    **Note**: model prediction defaults to velocity
    Args:
    - path_type: type of path to use; default to linear
    - prediction: model prediction target ("velocity", "score", or "noise")
    - loss_weight: optional loss weighting (None, "velocity", or "likelihood")
    - train_eps: small epsilon for avoiding instability during training
    - sample_eps: small epsilon for avoiding instability during sampling
    """

    if prediction == "noise":
        model_type = ModelType.NOISE
    elif prediction == "score":
        model_type = ModelType.SCORE
    else:
        model_type = ModelType.VELOCITY

    if loss_weight == "velocity":
        loss_type = WeightType.VELOCITY
    elif loss_weight == "likelihood":
        loss_type = WeightType.LIKELIHOOD
    else:
        loss_type = WeightType.NONE

    path_choice = {
        "Linear": PathType.LINEAR,
        "GVP": PathType.GVP,
        "VP": PathType.VP,
    }

    path_type = path_choice[path_type]

    if path_type in [PathType.VP]:
        # VP path is singular at the interval endpoints; clip with epsilons.
        train_eps = 1e-5 if train_eps is None else train_eps
        # BUG FIX: the default for sample_eps previously tested
        # `train_eps is None`, so passing an explicit train_eps while
        # omitting sample_eps left sample_eps == None.
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    elif path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY:
        # Score/noise parameterizations need clipping even on GVP/Linear paths.
        train_eps = 1e-3 if train_eps is None else train_eps
        sample_eps = 1e-3 if sample_eps is None else sample_eps  # BUG FIX: same wrong-variable check as above
    else:  # velocity & [GVP, LINEAR] is stable everywhere
        train_eps = 0
        sample_eps = 0

    # create flow state
    state = Transport(
        model_type=model_type,
        path_type=path_type,
        loss_type=loss_type,
        train_eps=train_eps,
        sample_eps=sample_eps,
    )

    return state
|
SiT_back/SiT_clean/transport/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (2.19 kB). View file
|
|
|
SiT_back/SiT_clean/transport/__pycache__/integrators.cpython-312.pyc
ADDED
|
Binary file (6.21 kB). View file
|
|
|
SiT_back/SiT_clean/transport/__pycache__/path.cpython-312.pyc
ADDED
|
Binary file (11.3 kB). View file
|
|
|
SiT_back/SiT_clean/transport/__pycache__/transport.cpython-312.pyc
ADDED
|
Binary file (20.4 kB). View file
|
|
|
SiT_back/SiT_clean/transport/__pycache__/utils.cpython-312.pyc
ADDED
|
Binary file (1.87 kB). View file
|
|
|
SiT_back/SiT_clean/transport/integrators.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch as th
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from torchdiffeq import odeint
|
| 5 |
+
from functools import partial
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
class sde:
    """SDE solver class.

    Integrates dx = drift(x, t) dt + sqrt(2 * diffusion(x, t)) dW on a uniform
    grid over [t0, t1] with Euler-Maruyama or a stochastic Heun step.
    """

    def __init__(
        self,
        drift,
        diffusion,
        *,
        t0,
        t1,
        num_steps,
        sampler_type,
    ):
        assert t0 < t1, "SDE sampler has to be in forward time"

        self.num_timesteps = num_steps
        self.t = th.linspace(t0, t1, num_steps)
        self.dt = self.t[1] - self.t[0]
        self.drift = drift
        self.diffusion = diffusion
        self.sampler_type = sampler_type

    def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
        """One Euler-Maruyama step; returns (noisy iterate, noise-free mean)."""
        w_cur = th.randn(x.size()).to(x)
        t = th.ones(x.size(0)).to(x) * t
        dw = w_cur * th.sqrt(self.dt)
        drift = self.drift(x, t, model, **model_kwargs)
        diffusion = self.diffusion(x, t)
        mean_x = x + drift * self.dt
        x = mean_x + th.sqrt(2 * diffusion) * dw
        return x, mean_x

    def __Heun_step(self, x, _, t, model, **model_kwargs):
        """One stochastic Heun step: inject noise first, then a trapezoidal drift update."""
        w_cur = th.randn(x.size()).to(x)
        dw = w_cur * th.sqrt(self.dt)
        t_cur = th.ones(x.size(0)).to(x) * t
        diffusion = self.diffusion(x, t_cur)
        xhat = x + th.sqrt(2 * diffusion) * dw
        K1 = self.drift(xhat, t_cur, model, **model_kwargs)
        xp = xhat + self.dt * K1
        K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
        return xhat + 0.5 * self.dt * (K1 + K2), xhat  # at last time point we do not perform the heun step

    def __forward_fn(self):
        """TODO: generalize here by adding all private functions ending with steps to it"""
        sampler_dict = {
            "Euler": self.__Euler_Maruyama_step,
            "Heun": self.__Heun_step,
        }

        # BUG FIX: was a bare `except:` (which would also mask unrelated errors)
        # with the misspelled message "Smapler type not implemented.".
        try:
            sampler = sampler_dict[self.sampler_type]
        except KeyError:
            raise NotImplementedError(f"Sampler type {self.sampler_type} not implemented.")

        return sampler

    def sample(self, init, model, **model_kwargs):
        """forward loop of sde; returns the list of iterates (initial point excluded)."""
        x = init
        mean_x = init
        samples = []
        sampler = self.__forward_fn()
        for ti in self.t[:-1]:
            with th.no_grad():
                x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
                samples.append(x)

        return samples
|
| 76 |
+
|
| 77 |
+
class ode:
    """ODE solver class: thin wrapper around torchdiffeq's black-box ``odeint``."""

    def __init__(
        self,
        drift,
        *,
        t0,
        t1,
        sampler_type,
        num_steps,
        atol,
        rtol,
    ):
        self.drift = drift
        self.t = th.linspace(t0, t1, num_steps)
        self.atol = atol
        self.rtol = rtol
        self.sampler_type = sampler_type

    def sample(self, x, model, **model_kwargs):
        """Integrate the drift over self.t starting from x; returns the solver trajectory."""
        is_tuple = isinstance(x, tuple)
        device = x[0].device if is_tuple else x.device

        def rhs(t, state):
            # Broadcast the scalar solver time to a per-sample time vector.
            batch = state[0].size(0) if is_tuple else state.size(0)
            t_vec = th.ones(batch).to(device) * t
            return self.drift(state, t_vec, model, **model_kwargs)

        n_states = len(x) if is_tuple else 1
        trajectory = odeint(
            rhs,
            x,
            self.t.to(device),
            method=self.sampler_type,
            atol=[self.atol] * n_states,
            rtol=[self.rtol] * n_states,
        )
        return trajectory
|
SiT_back/SiT_clean/transport/path.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch as th
|
| 2 |
+
import numpy as np
|
| 3 |
+
from functools import partial
|
| 4 |
+
|
| 5 |
+
def expand_t_like_x(t, x):
    """Reshape a [batch] time vector so it broadcasts against x.

    Args:
        t: [batch_dim,], time vector
        x: [batch_dim, ...], data point
    """
    trailing = len(x.size()) - 1
    return t.view(t.size(0), *([1] * trailing))
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
#################### Coupling Plans ####################
|
| 17 |
+
|
| 18 |
+
class ICPlan:
    """Linear Coupling Plan: x_t = alpha_t * x1 + sigma_t * x0 with alpha_t = t, sigma_t = 1 - t."""

    def __init__(self, sigma=0.0):
        self.sigma = sigma

    def compute_alpha_t(self, t):
        """Compute the data coefficient along the path; returns (alpha_t, d_alpha_t)."""
        return t, 1

    def compute_sigma_t(self, t):
        """Compute the noise coefficient along the path; returns (sigma_t, d_sigma_t)."""
        return 1 - t, -1

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Compute the ratio between d_alpha and alpha (diverges as t -> 0)."""
        return 1 / t

    def compute_drift(self, x, t):
        """We always output sde according to score parametrization; returns (drift, diffusion)."""
        t = expand_t_like_x(t, x)
        alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        drift = alpha_ratio * x
        diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t

        return -drift, diffusion

    def compute_diffusion(self, x, t, form="constant", norm=1.0):
        """Compute the diffusion term of the SDE
        Args:
          x: [batch_dim, ...], data point
          t: [batch_dim,], time vector
          form: str, form of the diffusion term
          norm: float, norm of the diffusion term
        """
        t = expand_t_like_x(t, x)
        # Lazy dispatch: the original dict computed EVERY entry eagerly, so e.g.
        # the SBDM drift (with its 1/t term) ran even when a different form was
        # requested.  Each form is now a thunk evaluated only on selection.
        choices = {
            "constant": lambda: norm,
            "SBDM": lambda: norm * self.compute_drift(x, t)[1],
            "sigma": lambda: norm * self.compute_sigma_t(t)[0],
            "linear": lambda: norm * (1 - t),
            "decreasing": lambda: 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
            # BUG FIX: key was misspelled "inccreasing-decreasing", so the CLI
            # choice "increasing-decreasing" always raised NotImplementedError.
            "increasing-decreasing": lambda: norm * th.sin(np.pi * t) ** 2,
        }

        try:
            diffusion = choices[form]()
        except KeyError:
            raise NotImplementedError(f"Diffusion form {form} not implemented")

        return diffusion

    def get_score_from_velocity(self, velocity, x, t):
        """Wrapper function: transform velocity prediction model to score
        Args:
          velocity: [batch_dim, ...] shaped tensor; velocity model output
          x: [batch_dim, ...] shaped tensor; x_t data point
          t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
        score = (reverse_alpha_ratio * velocity - mean) / var
        return score

    def get_noise_from_velocity(self, velocity, x, t):
        """Wrapper function: transform velocity prediction model to denoiser
        Args:
          velocity: [batch_dim, ...] shaped tensor; velocity model output
          x: [batch_dim, ...] shaped tensor; x_t data point
          t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = reverse_alpha_ratio * d_sigma_t - sigma_t
        noise = (reverse_alpha_ratio * velocity - mean) / var
        return noise

    def get_velocity_from_score(self, score, x, t):
        """Wrapper function: transform score prediction model to velocity
        Args:
          score: [batch_dim, ...] shaped tensor; score model output
          x: [batch_dim, ...] shaped tensor; x_t data point
          t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        drift, var = self.compute_drift(x, t)
        velocity = var * score - drift
        return velocity

    def compute_mu_t(self, t, x0, x1):
        """Compute the mean of time-dependent density p_t"""
        t = expand_t_like_x(t, x1)
        alpha_t, _ = self.compute_alpha_t(t)
        sigma_t, _ = self.compute_sigma_t(t)
        return alpha_t * x1 + sigma_t * x0

    def compute_xt(self, t, x0, x1):
        """Sample xt from time-dependent density p_t; rng is required"""
        xt = self.compute_mu_t(t, x0, x1)
        return xt

    def compute_ut(self, t, x0, x1, xt):
        """Compute the vector field corresponding to p_t"""
        t = expand_t_like_x(t, x1)
        _, d_alpha_t = self.compute_alpha_t(t)
        _, d_sigma_t = self.compute_sigma_t(t)
        return d_alpha_t * x1 + d_sigma_t * x0

    def plan(self, t, x0, x1):
        """Return (t, x_t, u_t): the interpolated point and its target vector field."""
        xt = self.compute_xt(t, x0, x1)
        ut = self.compute_ut(t, x0, x1, xt)
        return t, xt, ut
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class VPCPlan(ICPlan):
    """class for VP path flow matching"""

    def __init__(self, sigma_min=0.1, sigma_max=20.0):
        # Endpoints of the linear beta(t) noise schedule (DDPM-style VP SDE).
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        # log alpha_t of the VP path; parameterized in (1 - t), i.e. t=1 is the
        # data end of the path and t=0 the noise end.
        self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min
        # d/dt of log_mean_coeff above.
        self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min


    def compute_alpha_t(self, t):
        """Compute coefficient of x1"""
        alpha_t = self.log_mean_coeff(t)
        alpha_t = th.exp(alpha_t)
        # Chain rule: d/dt exp(f(t)) = exp(f(t)) * f'(t).
        d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
        return alpha_t, d_alpha_t

    def compute_sigma_t(self, t):
        """Compute coefficient of x0"""
        p_sigma_t = 2 * self.log_mean_coeff(t)
        # sigma_t = sqrt(1 - alpha_t^2): variance-preserving coupling.
        sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
        # Derivative of the line above; note it diverges as sigma_t -> 0 (t -> 1).
        d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
        return sigma_t, d_sigma_t

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
        # alpha_t = exp(log_mean_coeff), so d_alpha/alpha is just the log-derivative.
        return self.d_log_mean_coeff(t)

    def compute_drift(self, x, t):
        """Compute the drift term of the SDE"""
        t = expand_t_like_x(t, x)
        # beta(t) evaluated at reversed time (1 - t), matching the coeff lambdas.
        beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
        return -0.5 * beta_t * x, beta_t / 2
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class GVPCPlan(ICPlan):
    """Generalized VP coupling: alpha_t = sin(pi*t/2), sigma_t = cos(pi*t/2)."""

    def __init__(self, sigma=0.0):
        super().__init__(sigma)

    def compute_alpha_t(self, t):
        """Compute coefficient of x1 and its time derivative."""
        phase = t * np.pi / 2
        return th.sin(phase), np.pi / 2 * th.cos(phase)

    def compute_sigma_t(self, t):
        """Compute coefficient of x0 and its time derivative."""
        phase = t * np.pi / 2
        return th.cos(phase), -np.pi / 2 * th.sin(phase)

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Numerically stable d_alpha_t / alpha_t = (pi/2) * cot(pi*t/2)."""
        return np.pi / (2 * th.tan(t * np.pi / 2))
|
SiT_back/SiT_clean/transport/transport.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch as th
|
| 2 |
+
import numpy as np
|
| 3 |
+
import logging
|
| 4 |
+
|
| 5 |
+
import enum
|
| 6 |
+
|
| 7 |
+
from . import path
|
| 8 |
+
from .utils import EasyDict, log_state, mean_flat
|
| 9 |
+
from .integrators import ode, sde
|
| 10 |
+
|
| 11 |
+
class ModelType(enum.Enum):
    """
    Which type of output the model predicts.
    """

    NOISE = enum.auto()  # the model predicts epsilon
    SCORE = enum.auto()  # the model predicts \nabla \log p(x)
    VELOCITY = enum.auto()  # the model predicts v(x)
|
| 19 |
+
|
| 20 |
+
class PathType(enum.Enum):
    """
    Which type of path to use.
    """

    LINEAR = enum.auto()  # linear interpolant (ICPlan)
    GVP = enum.auto()  # generalized VP, trigonometric coefficients (GVPCPlan)
    VP = enum.auto()  # variance-preserving diffusion path (VPCPlan)
|
| 28 |
+
|
| 29 |
+
class WeightType(enum.Enum):
    """
    Which type of weighting to use.
    """

    NONE = enum.auto()  # unweighted loss
    VELOCITY = enum.auto()  # weight by (drift_var / sigma_t)^2
    LIKELIHOOD = enum.auto()  # weight by drift_var / sigma_t^2
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class Transport:
    """Stochastic-interpolant transport.

    Couples noise x0 with data x1 along a chosen path (Linear / GVP / VP) and
    exposes training losses plus probability-flow drift and score functions
    for a model predicting noise, score, or velocity.
    """

    def __init__(
        self,
        *,
        model_type,
        path_type,
        loss_type,
        train_eps,
        sample_eps,
    ):
        # Map each PathType enum member to its coupling-plan implementation.
        path_options = {
            PathType.LINEAR: path.ICPlan,
            PathType.GVP: path.GVPCPlan,
            PathType.VP: path.VPCPlan,
        }

        self.loss_type = loss_type
        self.model_type = model_type
        self.path_sampler = path_options[path_type]()
        self.train_eps = train_eps
        self.sample_eps = sample_eps

    def prior_logp(self, z):
        '''
        Standard multivariate normal prior
        Assume z is batched
        '''
        shape = th.tensor(z.size())
        N = th.prod(shape[1:])
        _fn = lambda x: -N / 2. * np.log(2 * np.pi) - th.sum(x ** 2) / 2.
        return th.vmap(_fn)(z)

    def check_interval(
        self,
        train_eps,
        sample_eps,
        *,
        diffusion_form="SBDM",
        sde=False,
        reverse=False,
        eval=False,  # NOTE: shadows the builtin; kept for caller compatibility
        last_step_size=0.0,
    ):
        """Return the integration interval (t0, t1), clipped away from the
        endpoints where the chosen path/parameterization is numerically unstable.
        """
        t0 = 0
        t1 = 1
        eps = train_eps if not eval else sample_eps
        if (type(self.path_sampler) in [path.VPCPlan]):
            # VP is singular at t=1 (sigma_t -> 0); back off by eps, or by the
            # explicit last-step size when an SDE last step will be taken.
            t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size

        elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) \
            and (self.model_type != ModelType.VELOCITY or sde):  # avoid numerical issue by taking a first semi-implicit step

            t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0
            t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size

        if reverse:
            t0, t1 = 1 - t0, 1 - t1

        return t0, t1

    def sample(self, x1):
        """Sampling x0 & t based on shape of x1 (if needed)
        Args:
          x1 - data point; [batch, *dim]
        """

        x0 = th.randn_like(x1)
        t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
        t = th.rand((x1.shape[0],)) * (t1 - t0) + t0
        t = t.to(x1)
        return t, x0, x1

    def training_losses(
        self,
        model,
        x1,
        model_kwargs=None
    ):
        """Loss for training the score model
        Args:
        - model: backbone model; could be score, noise, or velocity
        - x1: datapoint
        - model_kwargs: additional arguments for the model
        """
        if model_kwargs is None:  # FIX: identity check (was `== None`)
            model_kwargs = {}

        t, x0, x1 = self.sample(x1)
        t, xt, ut = self.path_sampler.plan(t, x0, x1)
        model_output = model(xt, t, **model_kwargs)
        # Model must preserve the spatial dims and channel count of xt.
        B, *_, C = xt.shape
        assert model_output.size() == (B, *xt.size()[1:-1], C)

        terms = {}
        terms['pred'] = model_output
        if self.model_type == ModelType.VELOCITY:
            # Flow-matching regression onto the target vector field u_t.
            terms['loss'] = mean_flat(((model_output - ut) ** 2))
        else:
            _, drift_var = self.path_sampler.compute_drift(xt, t)
            sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt))
            if self.loss_type in [WeightType.VELOCITY]:
                weight = (drift_var / sigma_t) ** 2
            elif self.loss_type in [WeightType.LIKELIHOOD]:
                weight = drift_var / (sigma_t ** 2)
            elif self.loss_type in [WeightType.NONE]:
                weight = 1
            else:
                raise NotImplementedError()

            if self.model_type == ModelType.NOISE:
                terms['loss'] = mean_flat(weight * ((model_output - x0) ** 2))
            else:
                # Score model: sigma_t * score + x0 vanishes at the optimum.
                terms['loss'] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2))

        return terms

    def get_drift(
        self
    ):
        """member function for obtaining the drift of the probability flow ODE"""
        def score_ode(x, t, model, **model_kwargs):
            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
            model_output = model(x, t, **model_kwargs)
            return (-drift_mean + drift_var * model_output)  # by change of variable

        def noise_ode(x, t, model, **model_kwargs):
            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
            sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
            model_output = model(x, t, **model_kwargs)
            # score = -eps / sigma_t
            score = model_output / -sigma_t
            return (-drift_mean + drift_var * score)

        def velocity_ode(x, t, model, **model_kwargs):
            model_output = model(x, t, **model_kwargs)
            return model_output

        if self.model_type == ModelType.NOISE:
            drift_fn = noise_ode
        elif self.model_type == ModelType.SCORE:
            drift_fn = score_ode
        else:
            drift_fn = velocity_ode

        def body_fn(x, t, model, **model_kwargs):
            model_output = drift_fn(x, t, model, **model_kwargs)
            assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
            return model_output

        return body_fn

    def get_score(
        self,
    ):
        """member function for obtaining score of
            x_t = alpha_t * x + sigma_t * eps"""
        if self.model_type == ModelType.NOISE:
            score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs) / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
        elif self.model_type == ModelType.SCORE:
            # FIX: parameter was misspelled **kwagrs (worked only by accident).
            score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs)
        elif self.model_type == ModelType.VELOCITY:
            score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t)
        else:
            raise NotImplementedError()

        return score_fn
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class Sampler:
|
| 214 |
+
"""Sampler class for the transport model"""
|
| 215 |
+
    def __init__(
        self,
        transport,
    ):
        """Constructor for a general sampler; supporting different sampling methods
        Args:
        - transport: a Transport object specifying model prediction & interpolant type
        """

        self.transport = transport
        # Probability-flow drift and score functions derived from the transport's
        # model parameterization (noise / score / velocity).
        self.drift = self.transport.get_drift()
        self.score = self.transport.get_score()
|
| 227 |
+
|
| 228 |
+
    def __get_sde_diffusion_and_drift(
        self,
        *,
        diffusion_form="SBDM",
        diffusion_norm=1.0,
    ):
        """Build (drift, diffusion) callables for the reverse-time sampling SDE."""

        # Diffusion coefficient w(t), delegated to the path sampler.
        def diffusion_fn(x, t):
            diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm)
            return diffusion

        # SDE drift = probability-flow drift + w(t) * score.
        sde_drift = \
            lambda x, t, model, **kwargs: \
                self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)

        sde_diffusion = diffusion_fn

        return sde_drift, sde_diffusion
|
| 246 |
+
|
| 247 |
+
    def __get_last_step(
        self,
        sde_drift,
        *,
        last_step,
        last_step_size,
    ):
        """Get the last step function of the SDE solver"""

        if last_step is None:
            # No correction: return the final SDE iterate unchanged.
            last_step_fn = \
                lambda x, t, model, **model_kwargs: \
                    x
        elif last_step == "Mean":
            # Deterministic mean step: follow the SDE drift without injecting noise.
            last_step_fn = \
                lambda x, t, model, **model_kwargs: \
                    x + sde_drift(x, t, model, **model_kwargs) * last_step_size
        elif last_step == "Tweedie":
            alpha = self.transport.path_sampler.compute_alpha_t # simple aliasing; the original name was too long
            sigma = self.transport.path_sampler.compute_sigma_t
            # Tweedie denoising estimate: x/alpha_t + (sigma_t^2 / alpha_t) * score.
            # NOTE(review): alpha(t)[0][0] / sigma(t)[0][0] index into the
            # (value, derivative) tuples — presumably assumes a scalar t; confirm.
            last_step_fn = \
                lambda x, t, model, **model_kwargs: \
                    x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][0] * self.score(x, t, model, **model_kwargs)
        elif last_step == "Euler":
            # Deterministic Euler step along the probability-flow ODE drift.
            last_step_fn = \
                lambda x, t, model, **model_kwargs: \
                    x + self.drift(x, t, model, **model_kwargs) * last_step_size
        else:
            raise NotImplementedError()

        return last_step_fn
|
| 278 |
+
|
| 279 |
+
def sample_sde(
|
| 280 |
+
self,
|
| 281 |
+
*,
|
| 282 |
+
sampling_method="Euler",
|
| 283 |
+
diffusion_form="SBDM",
|
| 284 |
+
diffusion_norm=1.0,
|
| 285 |
+
last_step="Mean",
|
| 286 |
+
last_step_size=0.04,
|
| 287 |
+
num_steps=250,
|
| 288 |
+
):
|
| 289 |
+
"""returns a sampling function with given SDE settings
|
| 290 |
+
Args:
|
| 291 |
+
- sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama
|
| 292 |
+
- diffusion_form: function form of diffusion coefficient; default to be matching SBDM
|
| 293 |
+
- diffusion_norm: function magnitude of diffusion coefficient; default to 1
|
| 294 |
+
- last_step: type of the last step; default to identity
|
| 295 |
+
- last_step_size: size of the last step; default to match the stride of 250 steps over [0,1]
|
| 296 |
+
- num_steps: total integration step of SDE
|
| 297 |
+
"""
|
| 298 |
+
|
| 299 |
+
if last_step is None:
|
| 300 |
+
last_step_size = 0.0
|
| 301 |
+
|
| 302 |
+
sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
|
| 303 |
+
diffusion_form=diffusion_form,
|
| 304 |
+
diffusion_norm=diffusion_norm,
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
t0, t1 = self.transport.check_interval(
|
| 308 |
+
self.transport.train_eps,
|
| 309 |
+
self.transport.sample_eps,
|
| 310 |
+
diffusion_form=diffusion_form,
|
| 311 |
+
sde=True,
|
| 312 |
+
eval=True,
|
| 313 |
+
reverse=False,
|
| 314 |
+
last_step_size=last_step_size,
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
_sde = sde(
|
| 318 |
+
sde_drift,
|
| 319 |
+
sde_diffusion,
|
| 320 |
+
t0=t0,
|
| 321 |
+
t1=t1,
|
| 322 |
+
num_steps=num_steps,
|
| 323 |
+
sampler_type=sampling_method
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def _sample(init, model, **model_kwargs):
|
| 330 |
+
xs = _sde.sample(init, model, **model_kwargs)
|
| 331 |
+
ts = th.ones(init.size(0), device=init.device) * t1
|
| 332 |
+
x = last_step_fn(xs[-1], ts, model, **model_kwargs)
|
| 333 |
+
xs.append(x)
|
| 334 |
+
|
| 335 |
+
assert len(xs) == num_steps, "Samples does not match the number of steps"
|
| 336 |
+
|
| 337 |
+
return xs
|
| 338 |
+
|
| 339 |
+
return _sample
|
| 340 |
+
|
| 341 |
+
def sample_ode(
|
| 342 |
+
self,
|
| 343 |
+
*,
|
| 344 |
+
sampling_method="dopri5",
|
| 345 |
+
num_steps=50,
|
| 346 |
+
atol=1e-6,
|
| 347 |
+
rtol=1e-3,
|
| 348 |
+
reverse=False,
|
| 349 |
+
):
|
| 350 |
+
"""returns a sampling function with given ODE settings
|
| 351 |
+
Args:
|
| 352 |
+
- sampling_method: type of sampler used in solving the ODE; default to be Dopri5
|
| 353 |
+
- num_steps:
|
| 354 |
+
- fixed solver (Euler, Heun): the actual number of integration steps performed
|
| 355 |
+
- adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
|
| 356 |
+
- atol: absolute error tolerance for the solver
|
| 357 |
+
- rtol: relative error tolerance for the solver
|
| 358 |
+
- reverse: whether solving the ODE in reverse (data to noise); default to False
|
| 359 |
+
"""
|
| 360 |
+
drift = self.drift
|
| 361 |
+
|
| 362 |
+
t0, t1 = self.transport.check_interval(
|
| 363 |
+
self.transport.train_eps,
|
| 364 |
+
self.transport.sample_eps,
|
| 365 |
+
sde=False,
|
| 366 |
+
eval=True,
|
| 367 |
+
reverse=reverse,
|
| 368 |
+
last_step_size=0.0,
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
_ode = ode(
|
| 372 |
+
drift=drift,
|
| 373 |
+
t0=t0,
|
| 374 |
+
t1=t1,
|
| 375 |
+
sampler_type=sampling_method,
|
| 376 |
+
num_steps=num_steps,
|
| 377 |
+
atol=atol,
|
| 378 |
+
rtol=rtol,
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
return _ode.sample
|
| 382 |
+
|
| 383 |
+
    def sample_ode_likelihood(
        self,
        *,
        sampling_method="dopri5",
        num_steps=50,
        atol=1e-6,
        rtol=1e-3,
    ):

        """returns a sampling function for calculating likelihood with given ODE settings
        Args:
        - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
        - num_steps:
            - fixed solver (Euler, Heun): the actual number of integration steps performed
            - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
        - atol: absolute error tolerance for the solver
        - rtol: relative error tolerance for the solver
        """
        def _likelihood_drift(x, t, model, **model_kwargs):
            # Augmented drift: evolves the state together with the accumulated
            # log-density change, estimating the divergence of the drift with the
            # Hutchinson trace estimator.
            x, _ = x
            # Rademacher probe: entries in {0,1} mapped to {-1,+1}.
            eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
            # Flip time so the integration runs in the reverse direction.
            t = th.ones_like(t) * (1 - t)
            with th.enable_grad():
                x.requires_grad = True
                # Vector-Jacobian product eps^T (d drift / dx) via autograd.
                grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
                # Hutchinson estimate of the Jacobian trace: eps^T J eps,
                # summed over all non-batch dimensions.
                logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
                drift = self.drift(x, t, model, **model_kwargs)
            return (-drift, logp_grad)

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            sde=False,
            eval=True,
            reverse=False,
            last_step_size=0.0,
        )

        _ode = ode(
            drift=_likelihood_drift,
            t0=t0,
            t1=t1,
            sampler_type=sampling_method,
            num_steps=num_steps,
            atol=atol,
            rtol=rtol,
        )

        def _sample_fn(x, model, **model_kwargs):
            # Start with zero accumulated log-density change per batch element.
            init_logp = th.zeros(x.size(0)).to(x)
            input = (x, init_logp)
            drift, delta_logp = _ode.sample(input, model, **model_kwargs)
            # Keep only the final state/divergence of each trajectory.
            drift, delta_logp = drift[-1], delta_logp[-1]
            prior_logp = self.transport.prior_logp(drift)
            # Change of variables: log p(x) = log prior - accumulated divergence.
            logp = prior_logp - delta_logp
            return logp, drift

        return _sample_fn
SiT_back/SiT_clean/transport/utils.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch as th
|
| 2 |
+
|
| 3 |
+
class EasyDict:
    """A minimal namespace wrapper: dict entries become attributes,
    while still supporting `obj[key]` lookup."""

    def __init__(self, sub_dict):
        # Copy every entry onto the instance as an attribute.
        self.__dict__.update(sub_dict)

    def __getitem__(self, key):
        return getattr(self, key)
|
| 12 |
+
def mean_flat(x):
    """
    Take the mean over all non-batch dimensions.
    """
    non_batch_dims = tuple(range(1, x.dim()))
    return x.mean(dim=non_batch_dims)
+
|
| 18 |
+
def log_state(state):
    """Render a state mapping as newline-joined 'key: value' lines, sorted by key.

    Values whose str() looks like a default object repr are shown as just
    their class name in brackets, to keep the log readable."""
    lines = []
    for key in sorted(state):
        value = state[key]
        rendered = str(value)
        # Default reprs look like '<pkg.Cls object at 0x...>'.
        if "<object" in rendered or "object at" in rendered:
            lines.append(f"{key}: [{value.__class__.__name__}]")
        else:
            lines.append(f"{key}: {value}")

    return '\n'.join(lines)
|
SiT_back/SiT_clean/wandb_utils.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import wandb
|
| 2 |
+
import torch
|
| 3 |
+
from torchvision.utils import make_grid
|
| 4 |
+
import torch.distributed as dist
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import os
|
| 7 |
+
import argparse
|
| 8 |
+
import hashlib
|
| 9 |
+
import math
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def is_main_process():
    """Return True only on the rank-0 process of the distributed group.

    NOTE(review): assumes torch.distributed is already initialized — confirm
    callers only use this after init_process_group."""
    rank = dist.get_rank()
    return rank == 0
| 14 |
+
|
| 15 |
+
def namespace_to_dict(namespace):
    """Recursively convert an argparse.Namespace (including nested ones)
    into a plain dict."""
    result = {}
    for key, value in vars(namespace).items():
        if isinstance(value, argparse.Namespace):
            value = namespace_to_dict(value)
        result[key] = value
    return result
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def generate_run_id(exp_name):
    """Derive a stable (at most 8-digit) numeric run id from the experiment name.

    https://stackoverflow.com/questions/16008670/how-to-hash-a-string-into-8-digits
    """
    digest = hashlib.sha256(exp_name.encode('utf-8')).hexdigest()
    return str(int(digest, 16) % 10 ** 8)
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def initialize(args, entity, exp_name, project_name):
    """Log in to Weights & Biases and start (or resume) a run for this experiment.

    Requires the WANDB_KEY environment variable; the run id is derived
    deterministically from exp_name so reruns resume the same run."""
    wandb.login(key=os.environ["WANDB_KEY"])
    run_settings = dict(
        entity=entity,
        project=project_name,
        name=exp_name,
        config=namespace_to_dict(args),
        id=generate_run_id(exp_name),
        resume="allow",
    )
    wandb.init(**run_settings)
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def log(stats, step=None):
    """Forward a stats mapping to wandb, but only from the rank-0 process."""
    if not is_main_process():
        return
    wandb.log(dict(stats), step=step)
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def log_image(sample, step=None):
    """Log a batch of sample images to wandb as a single image grid (rank-0 only)."""
    if is_main_process():
        grid = array2grid(sample)
        wandb.log({f"samples": wandb.Image(grid), "train_step": step})
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def array2grid(x):
    """Tile a batch of images (values in [-1, 1]) into one HxWxC uint8 numpy array."""
    # Roughly square layout: sqrt(batch) images per row.
    per_row = round(math.sqrt(x.size(0)))
    grid = make_grid(x, nrow=per_row, normalize=True, value_range=(-1, 1))
    # Scale [0, 1] -> [0, 255] with rounding, then move channels last for numpy.
    grid = grid.mul(255).add_(0.5).clamp_(0, 255)
    return grid.permute(1, 2, 0).to('cpu', torch.uint8).numpy()