xiangzai commited on
Commit
d31b843
·
verified ·
1 Parent(s): 9734e98

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. Rectified_Noise/GVP-Disp/W_No.log +5 -0
  2. Rectified_Noise/GVP-Disp/W_True_0.15.log +5 -0
  3. Rectified_Noise/GVP-Disp/W_True_0.5.log +5 -0
  4. Rectified_Noise/GVP-Disp/download.py +41 -0
  5. Rectified_Noise/GVP-Disp/environment.yml +16 -0
  6. Rectified_Noise/GVP-Disp/models.py +647 -0
  7. Rectified_Noise/GVP-Disp/sample_ddp.py +233 -0
  8. Rectified_Noise/GVP-Disp/sample_rectified_noise.py +380 -0
  9. Rectified_Noise/GVP-Disp/train_utils.py +35 -0
  10. Rectified_Noise/GVP-Disp/w_training1_VP.log +628 -0
  11. Rectified_Noise/GVP-Disp/权重类型分析.md +133 -0
  12. Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000032.png +0 -0
  13. Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000077.png +0 -0
  14. Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000133.png +0 -0
  15. Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000161.png +0 -0
  16. Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000220.png +0 -0
  17. Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000331.png +0 -0
  18. Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000387.png +0 -0
  19. Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000505.png +0 -0
  20. Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000517.png +0 -0
  21. Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000551.png +0 -0
  22. Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000726.png +0 -0
  23. Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000817.png +0 -0
  24. Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000865.png +0 -0
  25. Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000914.png +0 -0
  26. Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000940.png +0 -0
  27. Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001043.png +0 -0
  28. Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001210.png +0 -0
  29. SiT_back/SiT_clean/W_training.log +110 -0
  30. SiT_back/SiT_clean/__pycache__/download.cpython-312.pyc +0 -0
  31. SiT_back/SiT_clean/__pycache__/models.cpython-312.pyc +0 -0
  32. SiT_back/SiT_clean/__pycache__/train_utils.cpython-312.pyc +0 -0
  33. SiT_back/SiT_clean/download.py +40 -0
  34. SiT_back/SiT_clean/models.py +370 -0
  35. SiT_back/SiT_clean/run.sh +0 -0
  36. SiT_back/SiT_clean/sample.py +144 -0
  37. SiT_back/SiT_clean/sample_ddp.py +233 -0
  38. SiT_back/SiT_clean/train.py +371 -0
  39. SiT_back/SiT_clean/train_utils.py +32 -0
  40. SiT_back/SiT_clean/transport/__init__.py +65 -0
  41. SiT_back/SiT_clean/transport/__pycache__/__init__.cpython-312.pyc +0 -0
  42. SiT_back/SiT_clean/transport/__pycache__/integrators.cpython-312.pyc +0 -0
  43. SiT_back/SiT_clean/transport/__pycache__/path.cpython-312.pyc +0 -0
  44. SiT_back/SiT_clean/transport/__pycache__/transport.cpython-312.pyc +0 -0
  45. SiT_back/SiT_clean/transport/__pycache__/utils.cpython-312.pyc +0 -0
  46. SiT_back/SiT_clean/transport/integrators.py +115 -0
  47. SiT_back/SiT_clean/transport/path.py +192 -0
  48. SiT_back/SiT_clean/transport/transport.py +440 -0
  49. SiT_back/SiT_clean/transport/utils.py +29 -0
  50. SiT_back/SiT_clean/wandb_utils.py +55 -0
Rectified_Noise/GVP-Disp/W_No.log ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
0
  0%| | 0/47 [00:00<?, ?it/s]
1
  2%|▏ | 1/47 [02:10<1:39:52, 130.26s/it]
2
  4%|▍ | 2/47 [04:19<1:37:18, 129.75s/it]
3
  6%|▋ | 3/47 [06:29<1:35:03, 129.63s/it]
4
  9%|▊ | 4/47 [08:38<1:32:51, 129.58s/it]
5
  11%|█ | 5/47 [10:48<1:30:41, 129.55s/it]
6
  13%|█▎ | 6/47 [12:57<1:28:31, 129.54s/it]
7
  15%|█▍ | 7/47 [15:07<1:26:22, 129.56s/it]
8
  17%|█▋ | 8/47 [17:14<1:23:44, 128.82s/it]
9
  19%|█▉ | 9/47 [19:20<1:21:01, 127.93s/it]
10
  21%|██▏ | 10/47 [21:29<1:19:01, 128.15s/it]
11
  23%|██▎ | 11/47 [23:38<1:17:05, 128.47s/it]
12
  26%|██▌ | 12/47 [25:47<1:15:03, 128.67s/it]
13
  28%|██▊ | 13/47 [27:56<1:13:02, 128.91s/it]
 
1
+ [NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
2
+ Starting rank=0, seed=0, world_size=1.
3
+ Saving .png samples at GVP_samples/depth-mu-2-threshold-0.0-0025000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04
4
+ Total number of images that will be sampled: 3008
5
+
6
  0%| | 0/47 [00:00<?, ?it/s]
7
  2%|▏ | 1/47 [02:10<1:39:52, 130.26s/it]
8
  4%|▍ | 2/47 [04:19<1:37:18, 129.75s/it]
9
  6%|▋ | 3/47 [06:29<1:35:03, 129.63s/it]
10
  9%|▊ | 4/47 [08:38<1:32:51, 129.58s/it]
11
  11%|█ | 5/47 [10:48<1:30:41, 129.55s/it]
12
  13%|█▎ | 6/47 [12:57<1:28:31, 129.54s/it]
13
  15%|█▍ | 7/47 [15:07<1:26:22, 129.56s/it]
14
  17%|█▋ | 8/47 [17:14<1:23:44, 128.82s/it]
15
  19%|█▉ | 9/47 [19:20<1:21:01, 127.93s/it]
16
  21%|██▏ | 10/47 [21:29<1:19:01, 128.15s/it]
17
  23%|██▎ | 11/47 [23:38<1:17:05, 128.47s/it]
18
  26%|██▌ | 12/47 [25:47<1:15:03, 128.67s/it]
19
  28%|██▊ | 13/47 [27:56<1:13:02, 128.91s/it]
Rectified_Noise/GVP-Disp/W_True_0.15.log ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
0
  0%| | 0/47 [00:00<?, ?it/s]
1
  2%|▏ | 1/47 [01:34<1:12:45, 94.90s/it]
2
  4%|▍ | 2/47 [03:08<1:10:43, 94.31s/it]
3
  6%|▋ | 3/47 [04:42<1:09:03, 94.17s/it]
4
  9%|▊ | 4/47 [06:16<1:07:23, 94.04s/it]
5
  11%|█ | 5/47 [07:50<1:05:47, 93.99s/it]
6
  13%|█▎ | 6/47 [09:24<1:04:09, 93.88s/it]
7
  15%|█▍ | 7/47 [10:58<1:02:34, 93.85s/it]
8
  17%|█▋ | 8/47 [12:31<1:00:57, 93.79s/it]
9
  19%|█▉ | 9/47 [14:05<59:23, 93.77s/it]
10
  21%|██▏ | 10/47 [15:39<57:48, 93.75s/it]
11
  23%|██▎ | 11/47 [17:12<56:15, 93.76s/it]
12
  26%|██▌ | 12/47 [18:46<54:41, 93.76s/it]
13
  28%|██▊ | 13/47 [20:20<53:08, 93.78s/it]
14
  30%|██▉ | 14/47 [21:54<51:36, 93.82s/it]
15
  32%|███▏ | 15/47 [23:28<50:01, 93.80s/it]
16
  34%|███▍ | 16/47 [25:01<48:25, 93.71s/it]
17
  36%|███▌ | 17/47 [26:35<46:53, 93.77s/it]
18
  38%|███▊ | 18/47 [28:09<45:19, 93.78s/it]
19
  40%|████ | 19/47 [29:43<43:45, 93.77s/it]
 
1
+ [NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
2
+ Starting rank=0, seed=0, world_size=1.
3
+ Saving .png samples at GVP_samples/depth-mu-2-threshold-0.15-0025000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04
4
+ Total number of images that will be sampled: 3008
5
+
6
  0%| | 0/47 [00:00<?, ?it/s]
7
  2%|▏ | 1/47 [01:34<1:12:45, 94.90s/it]
8
  4%|▍ | 2/47 [03:08<1:10:43, 94.31s/it]
9
  6%|▋ | 3/47 [04:42<1:09:03, 94.17s/it]
10
  9%|▊ | 4/47 [06:16<1:07:23, 94.04s/it]
11
  11%|█ | 5/47 [07:50<1:05:47, 93.99s/it]
12
  13%|█▎ | 6/47 [09:24<1:04:09, 93.88s/it]
13
  15%|█▍ | 7/47 [10:58<1:02:34, 93.85s/it]
14
  17%|█▋ | 8/47 [12:31<1:00:57, 93.79s/it]
15
  19%|█▉ | 9/47 [14:05<59:23, 93.77s/it]
16
  21%|██▏ | 10/47 [15:39<57:48, 93.75s/it]
17
  23%|██▎ | 11/47 [17:12<56:15, 93.76s/it]
18
  26%|██▌ | 12/47 [18:46<54:41, 93.76s/it]
19
  28%|██▊ | 13/47 [20:20<53:08, 93.78s/it]
20
  30%|██▉ | 14/47 [21:54<51:36, 93.82s/it]
21
  32%|███▏ | 15/47 [23:28<50:01, 93.80s/it]
22
  34%|███▍ | 16/47 [25:01<48:25, 93.71s/it]
23
  36%|███▌ | 17/47 [26:35<46:53, 93.77s/it]
24
  38%|███▊ | 18/47 [28:09<45:19, 93.78s/it]
25
  40%|████ | 19/47 [29:43<43:45, 93.77s/it]
Rectified_Noise/GVP-Disp/W_True_0.5.log ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
0
  0%| | 0/47 [00:00<?, ?it/s]
1
  2%|▏ | 1/47 [01:34<1:12:36, 94.71s/it]
2
  4%|▍ | 2/47 [03:08<1:10:34, 94.11s/it]
3
  6%|▋ | 3/47 [04:42<1:08:54, 93.97s/it]
4
  9%|▊ | 4/47 [06:15<1:07:14, 93.82s/it]
5
  11%|█ | 5/47 [07:49<1:05:39, 93.79s/it]
6
  13%|█▎ | 6/47 [09:23<1:04:03, 93.75s/it]
7
  15%|█▍ | 7/47 [10:57<1:02:31, 93.79s/it]
8
  17%|█▋ | 8/47 [12:30<1:00:56, 93.75s/it]
9
  19%|█▉ | 9/47 [14:04<59:21, 93.73s/it]
10
  21%|██▏ | 10/47 [15:38<57:47, 93.73s/it]
11
  23%|██▎ | 11/47 [17:11<56:14, 93.74s/it]
12
  26%|██▌ | 12/47 [18:45<54:40, 93.74s/it]
13
  28%|██▊ | 13/47 [20:19<53:06, 93.72s/it]
14
  30%|██▉ | 14/47 [21:53<51:34, 93.77s/it]
15
  32%|███▏ | 15/47 [23:26<50:00, 93.77s/it]
16
  34%|███▍ | 16/47 [25:00<48:24, 93.70s/it]
17
  36%|███▌ | 17/47 [26:34<46:53, 93.77s/it]
18
  38%|███▊ | 18/47 [28:08<45:19, 93.79s/it]
19
  40%|████ | 19/47 [29:42<43:45, 93.78s/it]
 
1
+ [NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
2
+ Starting rank=0, seed=0, world_size=1.
3
+ Saving .png samples at GVP_samples/depth-mu-2-threshold-0.5-0025000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04
4
+ Total number of images that will be sampled: 3008
5
+
6
  0%| | 0/47 [00:00<?, ?it/s]
7
  2%|▏ | 1/47 [01:34<1:12:36, 94.71s/it]
8
  4%|▍ | 2/47 [03:08<1:10:34, 94.11s/it]
9
  6%|▋ | 3/47 [04:42<1:08:54, 93.97s/it]
10
  9%|▊ | 4/47 [06:15<1:07:14, 93.82s/it]
11
  11%|█ | 5/47 [07:49<1:05:39, 93.79s/it]
12
  13%|█▎ | 6/47 [09:23<1:04:03, 93.75s/it]
13
  15%|█▍ | 7/47 [10:57<1:02:31, 93.79s/it]
14
  17%|█▋ | 8/47 [12:30<1:00:56, 93.75s/it]
15
  19%|█▉ | 9/47 [14:04<59:21, 93.73s/it]
16
  21%|██▏ | 10/47 [15:38<57:47, 93.73s/it]
17
  23%|██▎ | 11/47 [17:11<56:14, 93.74s/it]
18
  26%|██▌ | 12/47 [18:45<54:40, 93.74s/it]
19
  28%|██▊ | 13/47 [20:19<53:06, 93.72s/it]
20
  30%|██▉ | 14/47 [21:53<51:34, 93.77s/it]
21
  32%|███▏ | 15/47 [23:26<50:00, 93.77s/it]
22
  34%|███▍ | 16/47 [25:00<48:24, 93.70s/it]
23
  36%|███▌ | 17/47 [26:34<46:53, 93.77s/it]
24
  38%|███▊ | 18/47 [28:08<45:19, 93.79s/it]
25
  40%|████ | 19/47 [29:42<43:45, 93.78s/it]
Rectified_Noise/GVP-Disp/download.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This source code is licensed under the license found in the
2
+ # LICENSE file in the root directory of this source tree.
3
+
4
+ """
5
+ Functions for downloading pre-trained SiT models
6
+ """
7
+ from torchvision.datasets.utils import download_url
8
+ import torch
9
+ import os
10
+
11
+
12
+ pretrained_models = {'SiT-XL-2-256x256.pt'}
13
+
14
+
15
+ def find_model(model_name):
16
+ """
17
+ Finds a pre-trained SiT model, downloading it if necessary. Alternatively, loads a model from a local path.
18
+ """
19
+ if model_name in pretrained_models:
20
+ return download_model(model_name)
21
+ else:
22
+ assert os.path.isfile(model_name), f'Could not find SiT checkpoint at {model_name}'
23
+ checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage, weights_only=False)
24
+ if "ema" in checkpoint: # supports checkpoints from train.py
25
+ checkpoint = checkpoint["ema"]
26
+ return checkpoint
27
+
28
+
29
+ def download_model(model_name):
30
+ """
31
+ Downloads a pre-trained SiT model from the web.
32
+ """
33
+ assert model_name in pretrained_models
34
+ local_path = f'pretrained_models/{model_name}'
35
+ if not os.path.isfile(local_path):
36
+ os.makedirs('pretrained_models', exist_ok=True)
37
+ web_path = f'https://www.dl.dropboxusercontent.com/scl/fi/as9oeomcbub47de5g4be0/SiT-XL-2-256.pt?rlkey=uxzxmpicu46coq3msb17b9ofa&dl=0'
38
+ download_url(web_path, 'pretrained_models', filename=model_name)
39
+ model = torch.load(local_path, map_location=lambda storage, loc: storage, weights_only=False)
40
+ return model
41
+
Rectified_Noise/GVP-Disp/environment.yml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: RN
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ dependencies:
6
+ - python >= 3.8
7
+ - pytorch >= 1.13
8
+ - torchvision
9
+ - pytorch-cuda >=11.7
10
+ - pip
11
+ - pip:
12
+ - timm
13
+ - diffusers
14
+ - accelerate
15
+ - torchdiffeq
16
+ - wandb
Rectified_Noise/GVP-Disp/models.py ADDED
@@ -0,0 +1,647 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This source code is licensed under the license found in the
2
+ # LICENSE file in the root directory of this source tree.
3
+ # --------------------------------------------------------
4
+ # References:
5
+ # GLIDE: https://github.com/openai/glide-text2im
6
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
7
+ # --------------------------------------------------------
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import numpy as np
12
+ import math
13
+ from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
14
+
15
+
16
+ def modulate(x, shift, scale):
17
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
18
+
19
+
20
+ #################################################################################
21
+ # Embedding Layers for Timesteps and Class Labels #
22
+ #################################################################################
23
+
24
+ class TimestepEmbedder(nn.Module):
25
+ """
26
+ Embeds scalar timesteps into vector representations.
27
+ """
28
+ def __init__(self, hidden_size, frequency_embedding_size=256):
29
+ super().__init__()
30
+ self.mlp = nn.Sequential(
31
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
32
+ nn.SiLU(),
33
+ nn.Linear(hidden_size, hidden_size, bias=True),
34
+ )
35
+ self.frequency_embedding_size = frequency_embedding_size
36
+
37
+ @staticmethod
38
+ def timestep_embedding(t, dim, max_period=10000):
39
+ """
40
+ Create sinusoidal timestep embeddings.
41
+ :param t: a 1-D Tensor of N indices, one per batch element.
42
+ These may be fractional.
43
+ :param dim: the dimension of the output.
44
+ :param max_period: controls the minimum frequency of the embeddings.
45
+ :return: an (N, D) Tensor of positional embeddings.
46
+ """
47
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
48
+ half = dim // 2
49
+ freqs = torch.exp(
50
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
51
+ ).to(device=t.device)
52
+ args = t[:, None].float() * freqs[None]
53
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
54
+ if dim % 2:
55
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
56
+ return embedding
57
+
58
+ def forward(self, t):
59
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
60
+ t_emb = self.mlp(t_freq)
61
+ return t_emb
62
+
63
+
64
+ class LabelEmbedder(nn.Module):
65
+ """
66
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
67
+ """
68
+ def __init__(self, num_classes, hidden_size, dropout_prob):
69
+ super().__init__()
70
+ use_cfg_embedding = dropout_prob > 0
71
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
72
+ self.num_classes = num_classes
73
+ self.dropout_prob = dropout_prob
74
+
75
+ def token_drop(self, labels, force_drop_ids=None):
76
+ """
77
+ Drops labels to enable classifier-free guidance.
78
+ """
79
+ if force_drop_ids is None:
80
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
81
+ else:
82
+ drop_ids = force_drop_ids == 1
83
+ labels = torch.where(drop_ids, self.num_classes, labels)
84
+ return labels
85
+
86
+ def forward(self, labels, train, force_drop_ids=None):
87
+ use_dropout = self.dropout_prob > 0
88
+ if (train and use_dropout) or (force_drop_ids is not None):
89
+ labels = self.token_drop(labels, force_drop_ids)
90
+ embeddings = self.embedding_table(labels)
91
+ return embeddings
92
+
93
+
94
+ #################################################################################
95
+ # Core SiT Model #
96
+ #################################################################################
97
+
98
+ class SiTBlock(nn.Module):
99
+ """
100
+ A SiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
101
+ """
102
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
103
+ super().__init__()
104
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
105
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
106
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
107
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
108
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
109
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
110
+ self.adaLN_modulation = nn.Sequential(
111
+ nn.SiLU(),
112
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
113
+ )
114
+
115
+ def forward(self, x, c):
116
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
117
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
118
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
119
+ return x
120
+
121
+
122
+ class FinalLayer(nn.Module):
123
+ """
124
+ The final layer of SiT.
125
+ """
126
+ def __init__(self, hidden_size, patch_size, out_channels):
127
+ super().__init__()
128
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
129
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
130
+ self.adaLN_modulation = nn.Sequential(
131
+ nn.SiLU(),
132
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
133
+ )
134
+
135
+ def forward(self, x, c):
136
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
137
+ x = modulate(self.norm_final(x), shift, scale)
138
+ x = self.linear(x)
139
+ return x
140
+
141
+
142
+ class SiT(nn.Module):
143
+ """
144
+ Diffusion model with a Transformer backbone.
145
+ """
146
+ def __init__(
147
+ self,
148
+ input_size=32,
149
+ patch_size=2,
150
+ in_channels=4,
151
+ hidden_size=1152,
152
+ depth=28,
153
+ num_heads=16,
154
+ mlp_ratio=4.0,
155
+ class_dropout_prob=0.1,
156
+ num_classes=1000,
157
+ learn_sigma=True,
158
+ ):
159
+ super().__init__()
160
+ self.learn_sigma = learn_sigma
161
+ self.learn_sigma = True
162
+ self.in_channels = in_channels
163
+ self.out_channels = in_channels * 2
164
+ self.patch_size = patch_size
165
+ self.num_heads = num_heads
166
+
167
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
168
+ self.t_embedder = TimestepEmbedder(hidden_size)
169
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
170
+ num_patches = self.x_embedder.num_patches
171
+ # Will use fixed sin-cos embedding:
172
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
173
+
174
+ self.blocks = nn.ModuleList([
175
+ SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
176
+ ])
177
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
178
+ self.initialize_weights()
179
+
180
+ def initialize_weights(self):
181
+ # Initialize transformer layers:
182
+ def _basic_init(module):
183
+ if isinstance(module, nn.Linear):
184
+ torch.nn.init.xavier_uniform_(module.weight)
185
+ if module.bias is not None:
186
+ nn.init.constant_(module.bias, 0)
187
+ self.apply(_basic_init)
188
+
189
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
190
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
191
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
192
+
193
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
194
+ w = self.x_embedder.proj.weight.data
195
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
196
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
197
+
198
+ # Initialize label embedding table:
199
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
200
+
201
+ # Initialize timestep embedding MLP:
202
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
203
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
204
+
205
+ # Zero-out adaLN modulation layers in SiT blocks:
206
+ for block in self.blocks:
207
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
208
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
209
+
210
+ # Zero-out output layers:
211
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
212
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
213
+ nn.init.constant_(self.final_layer.linear.weight, 0)
214
+ nn.init.constant_(self.final_layer.linear.bias, 0)
215
+
216
+ def unpatchify(self, x):
217
+ """
218
+ x: (N, T, patch_size**2 * C)
219
+ imgs: (N, H, W, C)
220
+ """
221
+ c = self.out_channels
222
+ p = self.x_embedder.patch_size[0]
223
+ h = w = int(x.shape[1] ** 0.5)
224
+ assert h * w == x.shape[1]
225
+
226
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
227
+ x = torch.einsum('nhwpqc->nchpwq', x)
228
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
229
+ return imgs
230
+
231
+ def forward(self, x, t, y, return_act=False):
232
+ """
233
+ Forward pass of SiT.
234
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
235
+ t: (N,) tensor of diffusion timesteps
236
+ y: (N,) tensor of class labels
237
+ return_act: if True, return activations from transformer blocks
238
+ """
239
+ act = []
240
+ x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
241
+ t = self.t_embedder(t) # (N, D)
242
+ y = self.y_embedder(y, self.training) # (N, D)
243
+ c = t + y # (N, D)
244
+ for block in self.blocks:
245
+ x = block(x, c) # (N, T, D)
246
+ if return_act:
247
+ act.append(x)
248
+ x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
249
+ x = self.unpatchify(x) # (N, out_channels, H, W)
250
+ if self.learn_sigma:
251
+ x, _ = x.chunk(2, dim=1)
252
+ if return_act:
253
+ return x, act
254
+ return x
255
+
256
+ def forward_with_cfg(self, x, t, y, cfg_scale):
257
+ """
258
+ Forward pass of SiT, but also batches the unconSiTional forward pass for classifier-free guidance.
259
+ """
260
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
261
+ half = x[: len(x) // 2]
262
+ combined = torch.cat([half, half], dim=0)
263
+ model_out = self.forward(combined, t, y)
264
+ # For exact reproducibility reasons, we apply classifier-free guidance on only
265
+ # three channels by default. The standard approach to cfg applies it to all channels.
266
+ # This can be done by uncommenting the following line and commenting-out the line following that.
267
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
268
+ eps, rest = model_out[:, :3], model_out[:, 3:]
269
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
270
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
271
+ eps = torch.cat([half_eps, half_eps], dim=0)
272
+ return torch.cat([eps, rest], dim=1)
273
+
274
+
275
+ #################################################################################
276
+ # Sine/Cosine Positional Embedding Functions #
277
+ #################################################################################
278
+ # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
279
+
280
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
281
+ """
282
+ grid_size: int of the grid height and width
283
+ return:
284
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
285
+ """
286
+ grid_h = np.arange(grid_size, dtype=np.float32)
287
+ grid_w = np.arange(grid_size, dtype=np.float32)
288
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
289
+ grid = np.stack(grid, axis=0)
290
+
291
+ grid = grid.reshape([2, 1, grid_size, grid_size])
292
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
293
+ if cls_token and extra_tokens > 0:
294
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
295
+ return pos_embed
296
+
297
+
298
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
299
+ assert embed_dim % 2 == 0
300
+
301
+ # use half of dimensions to encode grid_h
302
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
303
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
304
+
305
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
306
+ return emb
307
+
308
+
309
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
310
+ """
311
+ embed_dim: output dimension for each position
312
+ pos: a list of positions to be encoded: size (M,)
313
+ out: (M, D)
314
+ """
315
+ assert embed_dim % 2 == 0
316
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
317
+ omega /= embed_dim / 2.
318
+ omega = 1. / 10000**omega # (D/2,)
319
+
320
+ pos = pos.reshape(-1) # (M,)
321
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
322
+
323
+ emb_sin = np.sin(out) # (M, D/2)
324
+ emb_cos = np.cos(out) # (M, D/2)
325
+
326
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
327
+ return emb
328
+
329
+
330
+ #################################################################################
331
+ # SiT Configs #
332
+ #################################################################################
333
+
334
+ def SiT_XL_2(**kwargs):
335
+ return SiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
336
+
337
+ def SiT_XL_4(**kwargs):
338
+ return SiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
339
+
340
+ def SiT_XL_8(**kwargs):
341
+ return SiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
342
+
343
+ def SiT_L_2(**kwargs):
344
+ return SiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
345
+
346
+ def SiT_L_4(**kwargs):
347
+ return SiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
348
+
349
+ def SiT_L_8(**kwargs):
350
+ return SiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
351
+
352
+ def SiT_B_2(**kwargs):
353
+ return SiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
354
+
355
+ def SiT_B_4(**kwargs):
356
+ return SiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
357
+
358
+ def SiT_B_8(**kwargs):
359
+ return SiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
360
+
361
+ def SiT_S_2(**kwargs):
362
+ return SiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
363
+
364
+ def SiT_S_4(**kwargs):
365
+ return SiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
366
+
367
+ def SiT_S_8(**kwargs):
368
+ return SiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
369
+
370
+
371
+ SiT_models = {
372
+ 'SiT-XL/2': SiT_XL_2, 'SiT-XL/4': SiT_XL_4, 'SiT-XL/8': SiT_XL_8,
373
+ 'SiT-L/2': SiT_L_2, 'SiT-L/4': SiT_L_4, 'SiT-L/8': SiT_L_8,
374
+ 'SiT-B/2': SiT_B_2, 'SiT-B/4': SiT_B_4, 'SiT-B/8': SiT_B_8,
375
+ 'SiT-S/2': SiT_S_2, 'SiT-S/4': SiT_S_4, 'SiT-S/8': SiT_S_8,
376
+ }
377
+
378
+ #################################################################################
379
+ # SiTF1, SiTF2, CombinedModel #
380
+ #################################################################################
381
+
382
+ class SiTF1(nn.Module):
383
+ """
384
+ SiTF1 Model
385
+ """
386
+ def __init__(
387
+ self,
388
+ input_size=32,
389
+ patch_size=2,
390
+ in_channels=4,
391
+ hidden_size=1152,
392
+ depth=28,
393
+ num_heads=16,
394
+ mlp_ratio=4.0,
395
+ class_dropout_prob=0.1,
396
+ num_classes=1000,
397
+ learn_sigma=True,
398
+ final_layer=None,
399
+ ):
400
+ super().__init__()
401
+ self.input_size = input_size
402
+ self.patch_size= patch_size
403
+ self.hidden_size = hidden_size
404
+ self.in_channels = in_channels
405
+ self.out_channels = in_channels * 2
406
+ self.patch_size = patch_size
407
+ self.num_heads = num_heads
408
+ self.learn_sigma = learn_sigma
409
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
410
+ self.t_embedder = TimestepEmbedder(hidden_size)
411
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
412
+ num_patches = self.x_embedder.num_patches
413
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
414
+ self.blocks = nn.ModuleList([
415
+ SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
416
+ ])
417
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
418
+ self.initialize_weights()
419
+
420
+ def unpatchify(self, x):
421
+ """
422
+ x: (N, T, patch_size**2 * C)
423
+ imgs: (N, H, W, C)
424
+ """
425
+ c = self.out_channels
426
+ p = self.x_embedder.patch_size[0]
427
+ h = w = int(x.shape[1] ** 0.5)
428
+ assert h * w == x.shape[1]
429
+
430
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
431
+ x = torch.einsum('nhwpqc->nchpwq', x)
432
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
433
+ return imgs
434
+
435
+ def initialize_weights(self):
436
+ def _basic_init(module):
437
+ if isinstance(module, nn.Linear):
438
+ torch.nn.init.xavier_uniform_(module.weight)
439
+ if module.bias is not None:
440
+ nn.init.constant_(module.bias, 0)
441
+ self.apply(_basic_init)
442
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
443
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
444
+ w = self.x_embedder.proj.weight.data
445
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
446
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
447
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
448
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
449
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
450
+ for block in self.blocks:
451
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
452
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
453
+
454
+ def forward(self, x, t, y):
455
+ x = self.x_embedder(x) + self.pos_embed
456
+ t = self.t_embedder(t)
457
+ y = self.y_embedder(y, self.training)
458
+ c = t + y
459
+ for block in self.blocks:
460
+ x = block(x, c)
461
+ x_now = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
462
+ x_now = self.unpatchify(x_now) # (N, out_channels, H, W)
463
+ x_now, _ = x_now.chunk(2, dim=1)
464
+ return x,x_now # patch token (N, T, D)
465
+
466
+ def forward_with_cfg(self, x, t, y, cfg_scale):
467
+ """
468
+ Forward pass with classifier-free guidance for SiTF1.
469
+ Applies guidance consistently to both patch tokens and image output (x_now).
470
+ """
471
+ # Take the first half (conditional inputs) and duplicate it so that
472
+ # it can be paired with conditional and unconditional labels in `y`.
473
+ half = x[: len(x) // 2]
474
+ combined = torch.cat([half, half], dim=0)
475
+ patch_tokens, x_now = self.forward(combined, t, y)
476
+
477
+ # Apply CFG on the image output channels (first 3 channels by default)
478
+ eps, rest = x_now[:, :3, ...], x_now[:, 3:, ...]
479
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
480
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
481
+ eps = torch.cat([half_eps, half_eps], dim=0)
482
+ x_now = torch.cat([eps, rest], dim=1)
483
+
484
+ # Apply same guidance logic to patch tokens so downstream modules see
485
+ # a consistent guided representation.
486
+ cond_tok, uncond_tok = torch.split(patch_tokens, len(patch_tokens) // 2, dim=0)
487
+ half_tok = uncond_tok + cfg_scale * (cond_tok - uncond_tok)
488
+ patch_tokens = torch.cat([half_tok, half_tok], dim=0)
489
+
490
+ return patch_tokens, x_now
491
+
492
+
493
class SiTF2(nn.Module):
    """
    SiTF2: auxiliary refinement transformer.

    Unlike SiTF1, ``forward`` takes an already-embedded token sequence plus a
    conditioning vector (in CombinedModel this is the concatenation of
    SiTF1's patch tokens and re-embedded image patches, i.e. 2*num_patches
    tokens). When ``learn_sigma`` is set, the decoded output is split into a
    predicted mean and log-variance and a Gaussian sample is drawn from them.
    """
    def __init__(
        self,
        input_size=32,
        hidden_size=1152,
        out_channels=8,
        patch_size=2,
        num_heads=16,
        mlp_ratio=4.0,
        depth=4,
        learn_sigma=True,
        final_layer=None,
        num_classes=1000,
        class_dropout_prob=0.1,
        learn_mu=False,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.learn_mu = learn_mu
        self.out_channels = out_channels
        # Hard-coded to 4 latent channels (VAE latent space).
        self.in_channels = 4
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.blocks = nn.ModuleList([
            SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])
        # NOTE: x_embedder/t_embedder/y_embedder are constructed here but the
        # forward paths below consume pre-embedded tokens and a pre-built
        # conditioning vector; CombinedModel does the embedding itself.
        self.x_embedder = PatchEmbed(input_size, patch_size, self.in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        self.num_patches = num_patches  # Save original num_patches for unpatchify
        # pos_embed needs to support 2*num_patches for concatenated input
        self.pos_embed = nn.Parameter(torch.zeros(1, 2 * num_patches, hidden_size), requires_grad=False)
        # Initialize pos_embed with sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches ** 0.5))
        # Repeat the pos_embed for both halves (or could use different embeddings)
        pos_embed_full = np.concatenate([pos_embed, pos_embed], axis=0)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed_full).float().unsqueeze(0))

        if final_layer is not None:
            self.final_layer = final_layer
        else:
            self.final_layer = FinalLayer(hidden_size, patch_size, out_channels)
        # if depth !=0:
        #     for p in self.final_layer.parameters():
        #         if p is not None:
        #             torch.nn.init.constant_(p, 0)

    def unpatchify(self, x, patch_size, out_channels):
        """Reassemble patch tokens into (N, C, H, W) images.

        Accepts either ``num_patches`` or ``2 * num_patches`` tokens; in the
        concatenated case only the first half is decoded.
        """
        c = out_channels
        p = patch_size
        # x.shape[1] might be 2*num_patches when using concatenated input
        # Use original num_patches to calculate h and w
        h = w = int(self.num_patches ** 0.5)
        # If input has 2*num_patches, we need to handle it
        if x.shape[1] == 2 * self.num_patches:
            # Take only the first half (or average, or other strategy)
            # For now, we'll take the first half
            x = x[:, :self.num_patches, :]
        assert h * w == x.shape[1], f"Expected {h * w} patches, got {x.shape[1]}"
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

    def forward(self, x, c, t, return_act=False):
        """Refine token sequence ``x`` under conditioning ``c``.

        Args:
            x: (N, T, D) pre-embedded tokens (T may be 2 * num_patches).
            c: (N, D) conditioning vector (timestep + label embedding).
            t: accepted for interface compatibility.
               NOTE(review): ``t`` is never used in this method body.
            return_act: if True, also return per-block activations.

        Returns:
            (N, out_channels/2, H, W) image; when ``learn_sigma`` the output
            is a Gaussian sample drawn from the predicted mean/log-variance
            (mean omitted unless ``learn_mu``). Stochastic when learn_sigma.
        """
        act = []
        for block in self.blocks:
            x = block(x, c)
            if return_act:
                act.append(x)
        x = self.final_layer(x, c)
        x = self.unpatchify(x, self.patch_size, self.out_channels)
        if self.learn_sigma:
            # First channel half = predicted mean, second = predicted log-variance.
            mean_pred, log_var_pred = x.chunk(2, dim=1)
            variance_pred = torch.exp(log_var_pred)
            std_dev_pred = torch.sqrt(variance_pred)
            noise = torch.randn_like(mean_pred)
            #uniform_noise = torch.rand_like(mean_pred)
            #uniform_noise = uniform_noise.clamp(min=1e-5, max=1-1e-5)
            #gumbel_noise = -torch.log(-torch.log(uniform_noise))

            if self.learn_mu==True:
                # Full reparameterized sample: mu + sigma * eps.
                resampled_x = mean_pred + std_dev_pred * noise
            else:
                # Zero-mean sample: only the predicted scale is used.
                resampled_x = std_dev_pred * noise
            x = resampled_x
        else:
            x, _ = x.chunk(2, dim=1)
        if return_act:
            return x, act
        return x

    def forward_noise(self, x, c):
        """Same as ``forward`` without activation collection.

        NOTE(review): duplicates the forward logic (minus ``return_act``);
        candidate for consolidation with ``forward``.
        """
        for block in self.blocks:
            x = block(x, c)
        x = self.final_layer(x, c)
        x = self.unpatchify(x, self.patch_size, self.out_channels)
        if self.learn_sigma:
            mean_pred, log_var_pred = x.chunk(2, dim=1)
            variance_pred = torch.exp(log_var_pred)
            std_dev_pred = torch.sqrt(variance_pred)
            noise = torch.randn_like(mean_pred)
            if self.learn_mu==True:
                resampled_x = mean_pred + std_dev_pred * noise
            else:
                resampled_x = std_dev_pred * noise
            x = resampled_x
        else:
            x, _ = x.chunk(2, dim=1)
        return x
607
+
608
+ # Two design axes here: feed the ideal output vs. the real input, and combine by concatenation vs. by summation.
609
class CombinedModel(nn.Module):
    """
    CombinedModel: chains SiTF1 and SiTF2.

    SiTF1 produces patch tokens and a coarse image ``x_now``; the image is
    blended with the raw input, re-embedded into patches, concatenated with
    the patch tokens (2*T tokens), and refined by SiTF2.
    """
    def __init__(self, sitf1: SiTF1, sitf2: SiTF2):
        super().__init__()
        self.sitf1 = sitf1
        self.sitf2 = sitf2
        # Mirror SiTF1's geometry so the re-embedding matches its token grid.
        input_size=self.sitf1.input_size
        patch_size=self.sitf1.patch_size
        hidden_size=self.sitf1.hidden_size
        self.x_embedder = PatchEmbed(input_size, patch_size, 4, hidden_size, bias=True)
        num_patches = self.x_embedder.num_patches
        # pos_embed needs to support 2*num_patches for concatenated input
        self.pos_embed = nn.Parameter(torch.zeros(1, 2 * num_patches, hidden_size), requires_grad=False)
        # Initialize pos_embed with sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches ** 0.5))
        # Repeat the pos_embed for both halves (or could use different embeddings)
        pos_embed_full = np.concatenate([pos_embed, pos_embed], axis=0)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed_full).float().unsqueeze(0))

    def forward(self, x, t, y, return_act=False):
        patch_tokens,x_now = self.sitf1(x, t, y)
        # Blend SiTF1's decoded output with the raw input: (1 - t) * x_now + x.
        # t shape is (N,), broadcast to (N, 1, 1, 1) to match the image (N, C, H, W).
        t_broadcast = t.view(-1, 1, 1, 1)  # (N, 1, 1, 1)
        # NOTE(review): this is NOT the convex interpolation (1-t)*x_now + t*x —
        # ``x`` is added at full weight. Confirm whether that was intended before
        # changing it: trained checkpoints depend on the current formula.
        x_interpolated = (1 - t_broadcast) * x_now + x
        # Convert the blended image back to patch tokens (pos_embed added below).
        x_now_patches = self.x_embedder(x_interpolated)
        # Concatenate patch_tokens and x_now_patches along the sequence dimension
        concatenated_input = torch.cat([patch_tokens, x_now_patches], dim=1)  # (N, 2*T, D)
        # Add position embedding for the concatenated input
        # Use the same pos_embed for both halves (or could use different embeddings)
        concatenated_input = concatenated_input + self.pos_embed
        # Reuse SiTF1's embedders so conditioning is shared across both stages.
        t_emb = self.sitf1.t_embedder(t)
        y_emb = self.sitf1.y_embedder(y, self.training)
        c = t_emb + y_emb
        return self.sitf2(concatenated_input, c, t, return_act=return_act)
Rectified_Noise/GVP-Disp/sample_ddp.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This source code is licensed under the license found in the
2
+ # LICENSE file in the root directory of this source tree.
3
+
4
+ """
5
+ Samples a large number of images from a pre-trained SiT model using DDP.
6
+ Subsequently saves a .npz file that can be used to compute FID and other
7
+ evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations
8
+
9
+ For a simple single-GPU/CPU sampling script, see sample.py.
10
+ """
11
+ import torch
12
+ import torch.distributed as dist
13
+ from models import SiT_models
14
+ from download import find_model
15
+ from transport import create_transport, Sampler
16
+ from diffusers.models import AutoencoderKL
17
+ from train_utils import parse_ode_args, parse_sde_args, parse_transport_args
18
+ from tqdm import tqdm
19
+ import os
20
+ from PIL import Image
21
+ import numpy as np
22
+ import math
23
+ import argparse
24
+ import sys
25
+
26
+
27
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Builds a single .npz file from a folder of .png samples.

    Expects ``num`` sequentially numbered RGB files (``000000.png`` ...)
    in ``sample_dir``; writes ``<sample_dir>.npz`` with key ``arr_0`` and
    returns the archive path.
    """
    stacked = np.stack([
        np.asarray(Image.open(f"{sample_dir}/{idx:06d}.png")).astype(np.uint8)
        for idx in tqdm(range(num), desc="Building .npz file from samples")
    ])
    # All samples must be RGB and share one resolution.
    assert stacked.shape == (num, stacked.shape[1], stacked.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=stacked)
    print(f"Saved .npz file to {npz_path} [shape={stacked.shape}].")
    return npz_path
42
+
43
+
44
def main(mode, args):
    """
    Run DDP sampling with a SiT model and save samples as PNGs (plus a final
    .npz for FID evaluation on rank 0).

    Args:
        mode: "ODE" or "SDE" — selects the sampler family.
        args: parsed CLI namespace (model/checkpoint/sampling options).
    """
    torch.backends.cuda.matmul.allow_tf32 = args.tf32  # True: fast but may lead to some small numerical differences
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Setup DDP:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    # Per-rank deterministic seed so ranks draw disjoint noise.
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    if args.ckpt is None:
        # Auto-download path supports only the released SiT-XL/2 256x256 model.
        assert args.model == "SiT-XL/2", "Only SiT-XL/2 models are available for auto-download."
        assert args.image_size in [256, 512]
        assert args.num_classes == 1000
        assert args.image_size == 256, "512x512 models are not yet available for auto-download." # remove this line when 512x512 models are available
        learn_sigma = args.image_size == 256
    else:
        learn_sigma = False

    # Load model:
    latent_size = args.image_size // 8
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        learn_sigma=learn_sigma,
    ).to(device)
    # Auto-download a pre-trained model or load a custom SiT checkpoint from train.py:
    ckpt_path = args.ckpt or f"SiT-XL-2-{args.image_size}x{args.image_size}.pt"
    state_dict = find_model(ckpt_path)
    model.load_state_dict(state_dict)
    model.eval()  # important!


    transport = create_transport(
        args.path_type,
        args.prediction,
        args.loss_weight,
        args.train_eps,
        args.sample_eps
    )
    sampler = Sampler(transport)
    # Build the sampling function for the requested solver family.
    if mode == "ODE":
        if args.likelihood:
            assert args.cfg_scale == 1, "Likelihood is incompatible with guidance"
            sample_fn = sampler.sample_ode_likelihood(
                sampling_method=args.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=args.atol,
                rtol=args.rtol,
            )
        else:
            sample_fn = sampler.sample_ode(
                sampling_method=args.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=args.atol,
                rtol=args.rtol,
                reverse=args.reverse
            )
    elif mode == "SDE":
        sample_fn = sampler.sample_sde(
            sampling_method=args.sampling_method,
            diffusion_form=args.diffusion_form,
            diffusion_norm=args.diffusion_norm,
            last_step=args.last_step,
            last_step_size=args.last_step_size,
            num_steps=args.num_sampling_steps,
        )
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
    assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale be >= 1.0"
    using_cfg = args.cfg_scale > 1.0

    # Create folder to save samples:
    model_string_name = args.model.replace("/", "-")
    ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
    if mode == "ODE":
        folder_name = f"{model_string_name}-{ckpt_string_name}-" \
                      f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
                      f"{mode}-{args.num_sampling_steps}-{args.sampling_method}"
    elif mode == "SDE":
        folder_name = f"{model_string_name}-{ckpt_string_name}-" \
                      f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
                      f"{mode}-{args.num_sampling_steps}-{args.sampling_method}-"\
                      f"{args.diffusion_form}-{args.last_step}-{args.last_step_size}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()
    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
    num_samples = len([name for name in os.listdir(sample_folder_dir) if (os.path.isfile(os.path.join(sample_folder_dir, name)) and ".png" in name)])
    total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
    assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // dist.get_world_size())
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    # NOTE(review): done_iterations is computed but never used — resuming a
    # partially-filled sample folder is not actually implemented here.
    done_iterations = int( int(num_samples // dist.get_world_size()) // n)
    pbar = range(iterations)
    pbar = tqdm(pbar) if rank == 0 else pbar
    total = 0

    for i in pbar:
        # Sample inputs:
        z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device)
        y = torch.randint(0, args.num_classes, (n,), device=device)

        # Setup classifier-free guidance:
        if using_cfg:
            z = torch.cat([z, z], 0)
            # 1000 is the null-class index used for unconditional generation.
            y_null = torch.tensor([1000] * n, device=device)
            y = torch.cat([y, y_null], 0)
            model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
            model_fn = model.forward_with_cfg
        else:
            model_kwargs = dict(y=y)
            model_fn = model.forward

        samples = sample_fn(z, model_fn, **model_kwargs)[-1]
        if using_cfg:
            samples, _ = samples.chunk(2, dim=0)  # Remove null class samples

        # 0.18215 is the SD VAE latent scaling factor; map latents back to pixels.
        samples = vae.decode(samples / 0.18215).sample
        samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

        # Save samples to disk as individual .png files
        # NOTE(review): the inner `i` shadows the outer loop variable; harmless
        # today (outer `i` is unused afterwards) but worth renaming.
        for i, sample in enumerate(samples):
            # Interleave indices across ranks so filenames never collide.
            index = i * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
        total += global_batch_size
        dist.barrier()

    # Make sure all processes have finished saving their samples before attempting to convert to .npz
    dist.barrier()
    if rank == 0:
        create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
        print("Done.")
    dist.barrier()
    dist.destroy_process_group()
194
+
195
+
196
if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # The sampler family ("ODE" or "SDE") is read as the first positional CLI
    # token before argparse runs, so mode-specific flags can be registered.
    if len(sys.argv) < 2:
        print("Usage: program.py <mode> [options]")
        sys.exit(1)

    mode = sys.argv[1]

    assert mode[:2] != "--", "Usage: program.py <mode> [options]"
    assert mode in ["ODE", "SDE"], "Invalid mode. Please choose 'ODE' or 'SDE'"

    parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")
    parser.add_argument("--sample-dir", type=str, default="samples")
    parser.add_argument("--per-proc-batch-size", type=int, default=4)
    parser.add_argument("--num-fid-samples", type=int, default=50_000)
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--cfg-scale", type=float, default=1.0)
    parser.add_argument("--num-sampling-steps", type=int, default=250)
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
                        help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.")
    parser.add_argument("--ckpt", type=str, default=None,
                        help="Optional path to a SiT checkpoint (default: auto-download a pre-trained SiT-XL/2 model).")

    parse_transport_args(parser)
    if mode == "ODE":
        parse_ode_args(parser)
        # Further processing for ODE
    elif mode == "SDE":
        parse_sde_args(parser)
        # Further processing for SDE

    # parse_known_args ignores the positional <mode> token consumed above.
    args = parser.parse_known_args()[0]
    main(mode, args)
Rectified_Noise/GVP-Disp/sample_rectified_noise.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+ from torch.nn.parallel import DistributedDataParallel as DDP
4
+ from models import SiT_models
5
+ from download import find_model
6
+ from transport import create_transport, Sampler
7
+ from diffusers.models import AutoencoderKL
8
+ from train_utils import parse_ode_args, parse_sde_args, parse_transport_args
9
+ from tqdm import tqdm
10
+ import os
11
+ from PIL import Image
12
+ import numpy as np
13
+ import math
14
+ import argparse
15
+ import sys
16
+
17
+
18
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Builds a single .npz file from a folder of .png samples.

    Expects ``num`` sequentially numbered RGB files (``000000.png`` ...)
    in ``sample_dir``; writes ``<sample_dir>.npz`` (key ``arr_0``) and
    returns the archive path.
    """
    samples = []
    for i in tqdm(range(num), desc="Building .npz file from samples"):
        sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
        sample_np = np.asarray(sample_pil).astype(np.uint8)
        samples.append(sample_np)
    samples = np.stack(samples)
    # All samples must be RGB and share one resolution.
    assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=samples)
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path
33
+
34
+
35
def fix_state_dict_for_ddp(state_dict):
    """
    Normalize a checkpoint so it loads into a DistributedDataParallel model.

    If ``state_dict`` is a full training checkpoint (contains "model", "ema"
    or "opt"), the inner weights dict is extracted first, preferring "ema"
    over "model"; when only other keys (e.g. "opt") exist the dict is kept
    as-is. Every key then receives a "module." prefix unless it already has
    one, matching DDP's parameter naming.

    Args:
        state_dict: a raw model state dict or a full checkpoint dict.

    Returns:
        dict: state dict whose keys all start with "module.".
    """
    # Check if this is a full checkpoint dict with "model", "ema", or "opt" keys
    if isinstance(state_dict, dict) and ("model" in state_dict or "ema" in state_dict or "opt" in state_dict):
        # Prefer EMA weights, then the raw model weights.
        if "ema" in state_dict:
            state_dict = state_dict["ema"]
        elif "model" in state_dict:
            state_dict = state_dict["model"]
        # else: only "opt"/other keys exist; fall through with the original dict
        # (the dead `state_dict = state_dict` assignment was removed).

    # Prefix keys to match the DDP "module." naming.
    return {
        key if key.startswith("module.") else "module." + key: value
        for key, value in state_dict.items()
    }
61
+
62
def main(mode, args):
    """
    Run DDP sampling with the Rectified-Noise pipeline (base SiT + SiTF1/SiTF2
    via CombinedModel) and save PNG samples plus an .npz archive for FID.

    Args:
        mode: "ODE" or "SDE" — selects the sampler family.
        args: parsed CLI namespace (model/checkpoint/sampling options).
    """
    torch.backends.cuda.matmul.allow_tf32 = args.tf32  # True: fast but may lead to some small numerical differences
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)
    learn_mu = args.learn_mu
    sitf2_depth = args.depth  # Save SiTF2 depth before it gets overwritten

    # Setup DDP:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    if args.ckpt is None:
        assert args.model == "SiT-XL/2", "Only SiT-XL/2 models are available for auto-download."
        assert args.image_size in [256, 512]
        assert args.num_classes == 1000
        assert args.image_size == 256, "512x512 models are not yet available for auto-download."  # remove this line when 512x512 models are available
        learn_sigma = args.image_size == 256
    else:
        learn_sigma = False

    # Load SiTF1 and SiTF2 models and create CombinedModel
    from models import SiTF1, SiTF2, CombinedModel
    latent_size = args.image_size // 8

    # Get model configuration based on args.model
    model_name = args.model
    if 'XL' in model_name:
        hidden_size, depth, num_heads = 1152, 28, 16
    elif 'L' in model_name:
        hidden_size, depth, num_heads = 1024, 24, 16
    elif 'B' in model_name:
        hidden_size, depth, num_heads = 768, 12, 12
    elif 'S' in model_name:
        hidden_size, depth, num_heads = 384, 12, 6
    else:
        # Default fallback
        hidden_size, depth, num_heads = 768, 12, 12

    # Extract patch size from model name like 'SiT-XL/2' -> patch_size = 2
    patch_size = int(model_name.split('/')[-1])

    # Load SiTF1
    sitf1 = SiTF1(
        input_size=latent_size,
        patch_size=patch_size,
        in_channels=4,
        hidden_size=hidden_size,
        depth=depth,
        num_heads=num_heads,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=args.num_classes,
        learn_sigma=False
    ).to(device)
    sitf1_state_raw = find_model(args.ckpt)
    # find_model now returns ema if available, or the full checkpoint
    # Extract the actual state_dict to use for both sitf1 and base_model
    if isinstance(sitf1_state_raw, dict) and "model" in sitf1_state_raw:
        sitf1_state = sitf1_state_raw["model"]
    else:
        # sitf1_state_raw is already a state_dict (either ema or direct model state)
        sitf1_state = sitf1_state_raw
    sitf1.load_state_dict(sitf1_state)
    sitf1.eval()

    # For sampling, we can use sitf1 directly instead of creating a separate sit model
    # since sitf1 and sit have the same architecture and weights

    # Load SiTF2 with the same architecture parameters as SiTF1 for compatibility
    sitf2 = SiTF2(
        input_size=latent_size,
        hidden_size=hidden_size,  # Use the same hidden_size as SiTF1
        out_channels=8,
        patch_size=patch_size,  # Use the same patch_size as SiTF1
        num_heads=num_heads,  # Use the same num_heads as SiTF1
        mlp_ratio=4.0,
        depth=sitf2_depth,  # Use the depth specified by command line argument (not the model's default depth)
        learn_sigma=True,
        num_classes=args.num_classes,
        learn_mu=learn_mu
    ).to(device)
    sitf2 = DDP(sitf2, device_ids=[device])
    sitf2_state = find_model(args.sitf2_ckpt)
    # Fix state dict keys to match DDP model
    sitf2_state_fixed = fix_state_dict_for_ddp(sitf2_state)
    try:
        sitf2.load_state_dict(sitf2_state_fixed)
    except Exception as e:
        print(f"Error loading state dict: {e}")
        # Try loading with strict=False as fallback
        sitf2.load_state_dict(sitf2_state_fixed, strict=False)
    sitf2.eval()
    # CombinedModel

    combined_model = CombinedModel(sitf1, sitf2).to(device)
    sitf2.eval()
    combined_model.eval()

    # Use SiT_models factory function to create the base model, same as in SiT_clean
    # This ensures correct model configuration
    # Use learn_sigma=False to match sitf1 configuration
    base_model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        learn_sigma=False,  # Match sitf1's learn_sigma=False
    ).to(device)
    # Load the checkpoint (same as sitf1) - use the exact same state_dict
    base_model.load_state_dict(sitf1_state)
    base_model.eval()

    # Determine if CFG will be used (needed for combined_sampling_model function)
    assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale be >= 1.0"
    using_cfg = args.cfg_scale > 1.0

    # There are repeated calculations in the middle,
    # which will cause Flops to double. A simplified version will be released later
    def combined_sampling_model(x, t, y=None, **kwargs):
        """Velocity model used by the sampler: base SiT output, optionally
        augmented by the (time-masked) CombinedModel refinement."""
        with torch.no_grad():
            # Handle CFG same as in SiT_clean/sample_ddp.py
            if using_cfg and 'cfg_scale' in kwargs:
                # Use forward_with_cfg when CFG is enabled
                sit_out = base_model.forward_with_cfg(x, t, y, kwargs['cfg_scale'])
            else:
                # Use regular forward when CFG is disabled
                sit_out = base_model.forward(x, t, y)
            # If use_sitf2_before_t05 is True, only use sitf2 when t < threshold
            if args.use_sitf2 and args.use_sitf2_before_t05:
                # t is a tensor, check which samples have t < threshold
                # Create a mask: 1.0 where t < threshold, 0.0 otherwise
                mask = (t < args.sitf2_threshold).float()
                # Compute sitf2 output for all samples
                combined_out = combined_model.forward(x, t, y)
                # Expand mask to match the spatial dimensions of combined_out
                # combined_out shape is (batch, channels, height, width)
                while len(mask.shape) < len(combined_out.shape):
                    mask = mask.unsqueeze(-1)
                # Broadcast mask to match combined_out shape
                mask = mask.expand_as(combined_out)
                # Only use sitf2 output where t < threshold
                combined_out = combined_out * mask
                # Combine sit_out and masked combined_out
                return sit_out + combined_out
            # Default behavior: only use base model output
            return sit_out

    transport = create_transport(
        args.path_type,
        args.prediction,
        args.loss_weight,
        args.train_eps,
        args.sample_eps
    )
    sampler = Sampler(transport)
    if mode == "ODE":
        if args.likelihood:
            assert args.cfg_scale == 1, "Likelihood is incompatible with guidance"
            sample_fn = sampler.sample_ode_likelihood(
                sampling_method=args.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=args.atol,
                rtol=args.rtol,
            )
        else:
            sample_fn = sampler.sample_ode(
                sampling_method=args.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=args.atol,
                rtol=args.rtol,
                reverse=args.reverse
            )
    elif mode == "SDE":
        sample_fn = sampler.sample_sde(
            sampling_method=args.sampling_method,
            diffusion_form=args.diffusion_form,
            diffusion_norm=args.diffusion_norm,
            last_step=args.last_step,
            last_step_size=args.last_step_size,
            num_steps=args.num_sampling_steps,
        )
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)

    # Create folder to save samples:
    model_string_name = args.model.replace("/", "-")
    ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
    # Fix: the fallback label must depend on args.sitf2_ckpt, not args.ckpt —
    # otherwise omitting --ckpt mislabels the SiTF2 checkpoint as "pretrained".
    sitf2_ckpt_string_name = os.path.basename(args.sitf2_ckpt).replace(".pt", "") if args.sitf2_ckpt else "pretrained"
    if mode == "ODE":
        folder_name = f"{sitf2_ckpt_string_name}-{ckpt_string_name}-" \
                      f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
                      f"{mode}-{args.num_sampling_steps}-{args.sampling_method}"
    elif mode == "SDE":
        # Add threshold info to folder name if use_sitf2_before_t05 is enabled
        threshold_suffix = f"-threshold-{args.sitf2_threshold}" if args.use_sitf2_before_t05 else ""
        if learn_mu:
            folder_name = f"depth-mu-{sitf2_depth}{threshold_suffix}-{sitf2_ckpt_string_name}-{ckpt_string_name}-" \
                          f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
                          f"{mode}-{args.num_sampling_steps}-{args.sampling_method}-"\
                          f"{args.diffusion_form}-{args.last_step}-{args.last_step_size}"
        else:
            folder_name = f"depth-sigma-{sitf2_depth}{threshold_suffix}-{sitf2_ckpt_string_name}-{ckpt_string_name}-" \
                          f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
                          f"{mode}-{args.num_sampling_steps}-{args.sampling_method}-"\
                          f"{args.diffusion_form}-{args.last_step}-{args.last_step_size}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()
    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
    num_samples = len([name for name in os.listdir(sample_folder_dir) if (os.path.isfile(os.path.join(sample_folder_dir, name)) and ".png" in name)])
    total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
    assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // dist.get_world_size())
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    # NOTE(review): done_iterations is computed but never used (no resume support).
    done_iterations = int(int(num_samples // dist.get_world_size()) // n)
    pbar = range(iterations)
    pbar = tqdm(pbar) if rank == 0 else pbar
    total = 0

    for _ in pbar:
        # Sample inputs:
        z = torch.randn(n, base_model.in_channels, latent_size, latent_size, device=device)
        y = torch.randint(0, args.num_classes, (n,), device=device)
        # Setup classifier-free guidance:
        if using_cfg:
            z = torch.cat([z, z], 0)
            # 1000 is the null-class index used for unconditional generation.
            y_null = torch.tensor([1000] * n, device=device)
            y = torch.cat([y, y_null], 0)
            model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
        else:
            model_kwargs = dict(y=y)
        samples = sample_fn(z, combined_sampling_model, **model_kwargs)[-1]
        if using_cfg:
            samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
        # 0.18215 is the SD VAE latent scaling factor; map latents back to pixels.
        samples = vae.decode(samples / 0.18215).sample
        samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
        # Save samples to disk as individual .png files
        for i, sample in enumerate(samples):
            # Interleave indices across ranks so filenames never collide.
            index = i * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
        total += global_batch_size
        dist.barrier()

    # Make sure all processes have finished saving their samples before attempting to convert to .npz
    dist.barrier()
    if rank == 0:
        create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
        print("Done.")
    dist.barrier()
    dist.destroy_process_group()
331
+
332
+
333
if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # The sampler family ("ODE" or "SDE") is read as the first positional CLI
    # token before argparse runs, so mode-specific flags can be registered.
    if len(sys.argv) < 2:
        print("Usage: program.py <mode> [options]")
        sys.exit(1)

    mode = sys.argv[1]

    assert mode[:2] != "--", "Usage: program.py <mode> [options]"
    assert mode in ["ODE", "SDE"], "Invalid mode. Please choose 'ODE' or 'SDE'"

    parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")
    parser.add_argument("--sample-dir", type=str, default="samples")
    parser.add_argument("--per-proc-batch-size", type=int, default=64)
    parser.add_argument("--num-fid-samples", type=int, default=50_000)
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--cfg-scale", type=float, default=1.0)
    parser.add_argument("--num-sampling-steps", type=int, default=100)
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
                        help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.")
    parser.add_argument("--ckpt", type=str, default=None,
                        help="Optional path to a SiT checkpoint.")
    parser.add_argument("--sitf2-ckpt", type=str, required=True, help="Path to SiTF2 checkpoint")
    parser.add_argument("--learn-mu", action=argparse.BooleanOptionalAction, default=True,
                        help="Whether to learn mu parameter")
    parser.add_argument("--depth", type=int, default=1,
                        help="Depth parameter for SiTF2 model")
    parser.add_argument("--use-sitf2", action=argparse.BooleanOptionalAction, default=True,
                        help="Only use SiTF2 output when t < threshold, otherwise use only SiT")
    parser.add_argument("--use-sitf2-before-t05", action=argparse.BooleanOptionalAction, default=False,
                        help="Only use SiTF2 output when t < threshold, otherwise use only SiT")
    parser.add_argument("--sitf2-threshold", type=float, default=0.5,
                        help="Time threshold for using SiTF2 output (default: 0.5). Only effective when --use-sitf2-before-t05 is True")
    parse_transport_args(parser)
    if mode == "ODE":
        parse_ode_args(parser)
        # Further processing for ODE
    elif mode == "SDE":
        parse_sde_args(parser)
        # Further processing for SDE

    # parse_known_args ignores the positional <mode> token consumed above.
    args = parser.parse_known_args()[0]
    main(mode, args)
Rectified_Noise/GVP-Disp/train_utils.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def none_or_str(value):
    """Map the literal string 'None' to the Python None singleton; pass anything else through.

    Used as an argparse ``type=`` callable so a CLI flag can explicitly select None.
    """
    return None if value == 'None' else value
5
+
6
def parse_transport_args(parser):
    """Register transport-related CLI options on *parser* as one argument group.

    Adds: --path-type, --prediction, --loss-weight, --sample-eps, --train-eps.
    """
    def _none_or_str(value):
        # Inlined helper: the literal string 'None' selects the None default.
        return None if value == 'None' else value

    grp = parser.add_argument_group("Transport arguments")
    # NOTE(review): the help text glosses VP as "Velocity Path"; in the SiT
    # lineage VP usually means variance-preserving — worth confirming upstream.
    grp.add_argument(
        "--path-type", type=str, default="GVP", choices=["Linear", "GVP", "VP"],
        help="Type of path for loss calculation. This parameter directly affects the loss form used during training. "
             "Choices: Linear (linear interpolation path), GVP (Geodesic Velocity Path), VP (Velocity Path). "
             "The path_type determines how the transport loss is computed in training_losses().")
    grp.add_argument(
        "--prediction", type=str, default="velocity",
        choices=["velocity", "score", "noise"])
    grp.add_argument(
        "--loss-weight", type=_none_or_str, default=None,
        choices=[None, "velocity", "likelihood"])
    grp.add_argument("--sample-eps", type=float, default=0.0)
    grp.add_argument("--train-eps", type=float, default=0.0)
16
+
17
def parse_ode_args(parser):
    """Register ODE-solver CLI options on *parser* as one argument group.

    Adds: --sampling-method, --atol, --rtol, --reverse, --likelihood.
    """
    grp = parser.add_argument_group("ODE arguments")
    grp.add_argument(
        "--sampling-method", type=str, default="dopri5",
        help="blackbox ODE solver methods; for full list check https://github.com/rtqichen/torchdiffeq")
    # Solver tolerances: absolute then relative.
    grp.add_argument("--atol", type=float, default=1e-6, help="Absolute tolerance")
    grp.add_argument("--rtol", type=float, default=1e-3, help="Relative tolerance")
    # Boolean switches, off unless the flag is present.
    grp.add_argument("--reverse", action="store_true")
    grp.add_argument("--likelihood", action="store_true")
24
+
25
def parse_sde_args(parser):
    """Register SDE-sampler CLI options on *parser* as one argument group.

    Adds: --sampling-method, --diffusion-form, --diffusion-norm,
    --last-step, --last-step-size.
    """
    def _none_or_str(value):
        # Inlined helper: the literal string 'None' selects the None choice.
        return None if value == 'None' else value

    grp = parser.add_argument_group("SDE arguments")
    grp.add_argument(
        "--sampling-method", type=str, default="Euler",
        choices=["Euler", "Heun"])
    grp.add_argument(
        "--diffusion-form", type=str, default="sigma",
        choices=["constant", "SBDM", "sigma", "linear", "decreasing", "increasing-decreasing"],
        help="form of diffusion coefficient in the SDE")
    grp.add_argument("--diffusion-norm", type=float, default=1.0)
    # The final integration step may be handled specially; None disables that.
    grp.add_argument(
        "--last-step", type=_none_or_str, default="Mean",
        choices=[None, "Mean", "Tweedie", "Euler"],
        help="form of last step taken in the SDE")
    grp.add_argument(
        "--last-step-size", type=float, default=0.04,
        help="size of the last step taken")
Rectified_Noise/GVP-Disp/w_training1_VP.log ADDED
@@ -0,0 +1,628 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ nohup: ignoring input
2
+ [NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
3
+ Starting rank=0, seed=0, world_size=1.
4
+ [2026-02-01 14:09:25] Experiment directory created at results_256_vp/depth-mu-2-000-SiT-XL-2-VP-velocity-None
5
+ [2026-02-01 14:09:57] Combined_model Parameters: 729,629,632
6
+ [2026-02-01 14:09:57] Total trainable parameters: 53,910,176
7
+ [2026-02-01 14:09:59] Dataset contains 1,281,167 images (/gemini/platform/public/zhaozy/hzh/datasets/Imagenet/train/)
8
+ [2026-02-01 14:09:59] Training for 100000 epochs...
9
+ [2026-02-01 14:09:59] Beginning epoch 0...
10
+ [2026-02-01 14:11:18] (step=0000100) Train Loss: 2.7011, Train Steps/Sec: 1.27
11
+ [2026-02-01 14:12:34] (step=0000200) Train Loss: 1.9056, Train Steps/Sec: 1.32
12
+ [2026-02-01 14:13:49] (step=0000300) Train Loss: 1.7930, Train Steps/Sec: 1.32
13
+ [2026-02-01 14:15:05] (step=0000400) Train Loss: 2.0316, Train Steps/Sec: 1.32
14
+ [2026-02-01 14:16:21] (step=0000500) Train Loss: 1.8412, Train Steps/Sec: 1.32
15
+ [2026-02-01 14:17:37] (step=0000600) Train Loss: 1.8505, Train Steps/Sec: 1.32
16
+ [2026-02-01 14:18:53] (step=0000700) Train Loss: 1.8542, Train Steps/Sec: 1.32
17
+ [2026-02-01 14:20:09] (step=0000800) Train Loss: 1.8904, Train Steps/Sec: 1.32
18
+ [2026-02-01 14:21:25] (step=0000900) Train Loss: 1.9280, Train Steps/Sec: 1.32
19
+ [2026-02-01 14:22:41] (step=0001000) Train Loss: 1.8453, Train Steps/Sec: 1.32
20
+ [2026-02-01 14:23:57] (step=0001100) Train Loss: 1.8745, Train Steps/Sec: 1.32
21
+ [2026-02-01 14:25:13] (step=0001200) Train Loss: 1.8410, Train Steps/Sec: 1.32
22
+ [2026-02-01 14:26:29] (step=0001300) Train Loss: 1.8445, Train Steps/Sec: 1.32
23
+ [2026-02-01 14:27:44] (step=0001400) Train Loss: 1.8173, Train Steps/Sec: 1.32
24
+ [2026-02-01 14:29:00] (step=0001500) Train Loss: 3.5917, Train Steps/Sec: 1.32
25
+ [2026-02-01 14:30:16] (step=0001600) Train Loss: 1.8747, Train Steps/Sec: 1.32
26
+ [2026-02-01 14:31:32] (step=0001700) Train Loss: 1.8092, Train Steps/Sec: 1.32
27
+ [2026-02-01 14:32:48] (step=0001800) Train Loss: 1.8720, Train Steps/Sec: 1.32
28
+ [2026-02-01 14:34:04] (step=0001900) Train Loss: 1.8186, Train Steps/Sec: 1.32
29
+ [2026-02-01 14:35:20] (step=0002000) Train Loss: 1.9034, Train Steps/Sec: 1.32
30
+ [2026-02-01 14:36:36] (step=0002100) Train Loss: 1.8993, Train Steps/Sec: 1.32
31
+ [2026-02-01 14:37:52] (step=0002200) Train Loss: 1.8499, Train Steps/Sec: 1.32
32
+ [2026-02-01 14:39:08] (step=0002300) Train Loss: 2.1165, Train Steps/Sec: 1.32
33
+ [2026-02-01 14:40:24] (step=0002400) Train Loss: 1.8346, Train Steps/Sec: 1.32
34
+ [2026-02-01 14:41:40] (step=0002500) Train Loss: 1.7744, Train Steps/Sec: 1.32
35
+ [2026-02-01 14:42:56] (step=0002600) Train Loss: 1.8164, Train Steps/Sec: 1.32
36
+ [2026-02-01 14:44:12] (step=0002700) Train Loss: 1.8115, Train Steps/Sec: 1.32
37
+ [2026-02-01 14:45:28] (step=0002800) Train Loss: 1.8150, Train Steps/Sec: 1.32
38
+ [2026-02-01 14:46:44] (step=0002900) Train Loss: 1.8270, Train Steps/Sec: 1.32
39
+ [2026-02-01 14:48:00] (step=0003000) Train Loss: 1.9181, Train Steps/Sec: 1.32
40
+ [2026-02-01 14:49:16] (step=0003100) Train Loss: 1.9040, Train Steps/Sec: 1.32
41
+ [2026-02-01 14:50:31] (step=0003200) Train Loss: 2.2287, Train Steps/Sec: 1.32
42
+ [2026-02-01 14:51:47] (step=0003300) Train Loss: 2.0059, Train Steps/Sec: 1.32
43
+ [2026-02-01 14:53:03] (step=0003400) Train Loss: 1.8687, Train Steps/Sec: 1.32
44
+ [2026-02-01 14:54:19] (step=0003500) Train Loss: 1.9185, Train Steps/Sec: 1.32
45
+ [2026-02-01 14:55:35] (step=0003600) Train Loss: 1.9162, Train Steps/Sec: 1.32
46
+ [2026-02-01 14:56:51] (step=0003700) Train Loss: 2.0918, Train Steps/Sec: 1.32
47
+ [2026-02-01 14:58:07] (step=0003800) Train Loss: 2.5750, Train Steps/Sec: 1.32
48
+ [2026-02-01 14:59:23] (step=0003900) Train Loss: 1.8959, Train Steps/Sec: 1.32
49
+ [2026-02-01 15:00:39] (step=0004000) Train Loss: 1.8935, Train Steps/Sec: 1.32
50
+ [2026-02-01 15:01:55] (step=0004100) Train Loss: 1.8143, Train Steps/Sec: 1.32
51
+ [2026-02-01 15:03:11] (step=0004200) Train Loss: 2.0503, Train Steps/Sec: 1.32
52
+ [2026-02-01 15:04:27] (step=0004300) Train Loss: 1.8916, Train Steps/Sec: 1.32
53
+ [2026-02-01 15:05:43] (step=0004400) Train Loss: 2.1279, Train Steps/Sec: 1.32
54
+ [2026-02-01 15:06:59] (step=0004500) Train Loss: 1.8331, Train Steps/Sec: 1.32
55
+ [2026-02-01 15:08:15] (step=0004600) Train Loss: 1.8969, Train Steps/Sec: 1.32
56
+ [2026-02-01 15:09:31] (step=0004700) Train Loss: 1.8220, Train Steps/Sec: 1.32
57
+ [2026-02-01 15:10:47] (step=0004800) Train Loss: 1.8862, Train Steps/Sec: 1.32
58
+ [2026-02-01 15:12:03] (step=0004900) Train Loss: 1.9553, Train Steps/Sec: 1.32
59
+ [2026-02-01 15:13:19] (step=0005000) Train Loss: 1.8549, Train Steps/Sec: 1.31
60
+ [2026-02-01 15:14:35] (step=0005100) Train Loss: 1.9343, Train Steps/Sec: 1.32
61
+ [2026-02-01 15:15:51] (step=0005200) Train Loss: 1.9899, Train Steps/Sec: 1.32
62
+ [2026-02-01 15:17:07] (step=0005300) Train Loss: 1.9115, Train Steps/Sec: 1.32
63
+ [2026-02-01 15:18:23] (step=0005400) Train Loss: 2.2117, Train Steps/Sec: 1.32
64
+ [2026-02-01 15:19:39] (step=0005500) Train Loss: 1.9424, Train Steps/Sec: 1.32
65
+ [2026-02-01 15:20:55] (step=0005600) Train Loss: 1.8367, Train Steps/Sec: 1.32
66
+ [2026-02-01 15:22:11] (step=0005700) Train Loss: 1.8696, Train Steps/Sec: 1.32
67
+ [2026-02-01 15:23:27] (step=0005800) Train Loss: 2.2085, Train Steps/Sec: 1.32
68
+ [2026-02-01 15:24:43] (step=0005900) Train Loss: 1.8185, Train Steps/Sec: 1.32
69
+ [2026-02-01 15:25:59] (step=0006000) Train Loss: 1.8452, Train Steps/Sec: 1.32
70
+ [2026-02-01 15:27:15] (step=0006100) Train Loss: 1.8141, Train Steps/Sec: 1.32
71
+ [2026-02-01 15:28:31] (step=0006200) Train Loss: 2.4398, Train Steps/Sec: 1.32
72
+ [2026-02-01 15:29:47] (step=0006300) Train Loss: 1.9160, Train Steps/Sec: 1.32
73
+ [2026-02-01 15:31:03] (step=0006400) Train Loss: 1.9920, Train Steps/Sec: 1.32
74
+ [2026-02-01 15:32:19] (step=0006500) Train Loss: 1.8726, Train Steps/Sec: 1.32
75
+ [2026-02-01 15:33:35] (step=0006600) Train Loss: 1.9302, Train Steps/Sec: 1.32
76
+ [2026-02-01 15:34:51] (step=0006700) Train Loss: 1.8886, Train Steps/Sec: 1.32
77
+ [2026-02-01 15:36:07] (step=0006800) Train Loss: 1.8492, Train Steps/Sec: 1.32
78
+ [2026-02-01 15:37:23] (step=0006900) Train Loss: 2.0008, Train Steps/Sec: 1.32
79
+ [2026-02-01 15:38:39] (step=0007000) Train Loss: 1.9791, Train Steps/Sec: 1.32
80
+ [2026-02-01 15:39:55] (step=0007100) Train Loss: 1.9221, Train Steps/Sec: 1.32
81
+ [2026-02-01 15:41:11] (step=0007200) Train Loss: 1.8893, Train Steps/Sec: 1.32
82
+ [2026-02-01 15:42:27] (step=0007300) Train Loss: 1.8739, Train Steps/Sec: 1.32
83
+ [2026-02-01 15:43:43] (step=0007400) Train Loss: 2.6370, Train Steps/Sec: 1.32
84
+ [2026-02-01 15:44:59] (step=0007500) Train Loss: 2.1859, Train Steps/Sec: 1.32
85
+ [2026-02-01 15:46:15] (step=0007600) Train Loss: 1.8067, Train Steps/Sec: 1.32
86
+ [2026-02-01 15:47:31] (step=0007700) Train Loss: 1.8996, Train Steps/Sec: 1.32
87
+ [2026-02-01 15:48:47] (step=0007800) Train Loss: 1.9468, Train Steps/Sec: 1.32
88
+ [2026-02-01 15:50:03] (step=0007900) Train Loss: 1.8925, Train Steps/Sec: 1.32
89
+ [2026-02-01 15:51:19] (step=0008000) Train Loss: 1.7844, Train Steps/Sec: 1.32
90
+ [2026-02-01 15:52:35] (step=0008100) Train Loss: 1.9823, Train Steps/Sec: 1.32
91
+ [2026-02-01 15:53:51] (step=0008200) Train Loss: 1.9363, Train Steps/Sec: 1.32
92
+ [2026-02-01 15:55:07] (step=0008300) Train Loss: 1.8508, Train Steps/Sec: 1.32
93
+ [2026-02-01 15:56:22] (step=0008400) Train Loss: 1.9048, Train Steps/Sec: 1.32
94
+ [2026-02-01 15:57:38] (step=0008500) Train Loss: 1.8955, Train Steps/Sec: 1.32
95
+ [2026-02-01 15:58:54] (step=0008600) Train Loss: 1.8585, Train Steps/Sec: 1.32
96
+ [2026-02-01 16:00:10] (step=0008700) Train Loss: 1.8621, Train Steps/Sec: 1.32
97
+ [2026-02-01 16:01:26] (step=0008800) Train Loss: 1.8826, Train Steps/Sec: 1.32
98
+ [2026-02-01 16:02:43] (step=0008900) Train Loss: 1.9289, Train Steps/Sec: 1.31
99
+ [2026-02-01 16:03:59] (step=0009000) Train Loss: 1.9667, Train Steps/Sec: 1.32
100
+ [2026-02-01 16:05:15] (step=0009100) Train Loss: 2.1871, Train Steps/Sec: 1.32
101
+ [2026-02-01 16:06:31] (step=0009200) Train Loss: 1.8651, Train Steps/Sec: 1.32
102
+ [2026-02-01 16:07:47] (step=0009300) Train Loss: 1.9620, Train Steps/Sec: 1.32
103
+ [2026-02-01 16:09:03] (step=0009400) Train Loss: 1.8992, Train Steps/Sec: 1.32
104
+ [2026-02-01 16:10:18] (step=0009500) Train Loss: 1.8620, Train Steps/Sec: 1.32
105
+ [2026-02-01 16:11:34] (step=0009600) Train Loss: 1.9782, Train Steps/Sec: 1.32
106
+ [2026-02-01 16:12:50] (step=0009700) Train Loss: 2.3364, Train Steps/Sec: 1.32
107
+ [2026-02-01 16:14:06] (step=0009800) Train Loss: 1.8309, Train Steps/Sec: 1.32
108
+ [2026-02-01 16:15:22] (step=0009900) Train Loss: 2.5777, Train Steps/Sec: 1.32
109
+ [2026-02-01 16:16:38] (step=0010000) Train Loss: 1.9410, Train Steps/Sec: 1.32
110
+ [2026-02-01 16:16:45] Beginning epoch 1...
111
+ [2026-02-01 16:17:56] (step=0010100) Train Loss: 1.8156, Train Steps/Sec: 1.28
112
+ [2026-02-01 16:19:12] (step=0010200) Train Loss: 1.7965, Train Steps/Sec: 1.32
113
+ [2026-02-01 16:20:28] (step=0010300) Train Loss: 1.9732, Train Steps/Sec: 1.32
114
+ [2026-02-01 16:21:44] (step=0010400) Train Loss: 2.6702, Train Steps/Sec: 1.32
115
+ [2026-02-01 16:23:00] (step=0010500) Train Loss: 1.9175, Train Steps/Sec: 1.32
116
+ [2026-02-01 16:24:16] (step=0010600) Train Loss: 1.8493, Train Steps/Sec: 1.32
117
+ [2026-02-01 16:25:32] (step=0010700) Train Loss: 1.8514, Train Steps/Sec: 1.32
118
+ [2026-02-01 16:26:48] (step=0010800) Train Loss: 2.0059, Train Steps/Sec: 1.32
119
+ [2026-02-01 16:28:04] (step=0010900) Train Loss: 1.8519, Train Steps/Sec: 1.32
120
+ [2026-02-01 16:29:20] (step=0011000) Train Loss: 1.8523, Train Steps/Sec: 1.32
121
+ [2026-02-01 16:30:36] (step=0011100) Train Loss: 1.7980, Train Steps/Sec: 1.32
122
+ [2026-02-01 16:31:52] (step=0011200) Train Loss: 1.8429, Train Steps/Sec: 1.32
123
+ [2026-02-01 16:33:08] (step=0011300) Train Loss: 1.9200, Train Steps/Sec: 1.32
124
+ [2026-02-01 16:34:24] (step=0011400) Train Loss: 1.8371, Train Steps/Sec: 1.32
125
+ [2026-02-01 16:35:40] (step=0011500) Train Loss: 2.0173, Train Steps/Sec: 1.32
126
+ [2026-02-01 16:36:56] (step=0011600) Train Loss: 1.8135, Train Steps/Sec: 1.32
127
+ [2026-02-01 16:38:12] (step=0011700) Train Loss: 1.9532, Train Steps/Sec: 1.32
128
+ [2026-02-01 16:39:28] (step=0011800) Train Loss: 2.0043, Train Steps/Sec: 1.32
129
+ [2026-02-01 16:40:44] (step=0011900) Train Loss: 1.8474, Train Steps/Sec: 1.32
130
+ [2026-02-01 16:42:00] (step=0012000) Train Loss: 1.8364, Train Steps/Sec: 1.32
131
+ [2026-02-01 16:43:15] (step=0012100) Train Loss: 2.6696, Train Steps/Sec: 1.32
132
+ [2026-02-01 16:44:31] (step=0012200) Train Loss: 1.8652, Train Steps/Sec: 1.32
133
+ [2026-02-01 16:45:47] (step=0012300) Train Loss: 1.9174, Train Steps/Sec: 1.32
134
+ [2026-02-01 16:47:03] (step=0012400) Train Loss: 1.8479, Train Steps/Sec: 1.31
135
+ [2026-02-01 16:48:19] (step=0012500) Train Loss: 1.8228, Train Steps/Sec: 1.32
136
+ [2026-02-01 16:49:35] (step=0012600) Train Loss: 1.9067, Train Steps/Sec: 1.32
137
+ [2026-02-01 16:50:51] (step=0012700) Train Loss: 1.7572, Train Steps/Sec: 1.32
138
+ [2026-02-01 16:52:07] (step=0012800) Train Loss: 1.8446, Train Steps/Sec: 1.32
139
+ [2026-02-01 16:53:23] (step=0012900) Train Loss: 1.8543, Train Steps/Sec: 1.32
140
+ [2026-02-01 16:54:39] (step=0013000) Train Loss: 1.8222, Train Steps/Sec: 1.32
141
+ [2026-02-01 16:55:55] (step=0013100) Train Loss: 2.0108, Train Steps/Sec: 1.32
142
+ [2026-02-01 16:57:11] (step=0013200) Train Loss: 2.3761, Train Steps/Sec: 1.32
143
+ [2026-02-01 16:58:27] (step=0013300) Train Loss: 1.8902, Train Steps/Sec: 1.32
144
+ [2026-02-01 16:59:43] (step=0013400) Train Loss: 1.8800, Train Steps/Sec: 1.32
145
+ [2026-02-01 17:00:59] (step=0013500) Train Loss: 1.7917, Train Steps/Sec: 1.32
146
+ [2026-02-01 17:02:15] (step=0013600) Train Loss: 1.9730, Train Steps/Sec: 1.32
147
+ [2026-02-01 17:03:31] (step=0013700) Train Loss: 1.8894, Train Steps/Sec: 1.32
148
+ [2026-02-01 17:04:47] (step=0013800) Train Loss: 2.1075, Train Steps/Sec: 1.32
149
+ [2026-02-01 17:06:03] (step=0013900) Train Loss: 1.8469, Train Steps/Sec: 1.32
150
+ [2026-02-01 17:07:19] (step=0014000) Train Loss: 1.8705, Train Steps/Sec: 1.32
151
+ [2026-02-01 17:08:35] (step=0014100) Train Loss: 1.8630, Train Steps/Sec: 1.32
152
+ [2026-02-01 17:09:51] (step=0014200) Train Loss: 1.8509, Train Steps/Sec: 1.32
153
+ [2026-02-01 17:11:07] (step=0014300) Train Loss: 2.2249, Train Steps/Sec: 1.32
154
+ [2026-02-01 17:12:23] (step=0014400) Train Loss: 1.8378, Train Steps/Sec: 1.32
155
+ [2026-02-01 17:13:39] (step=0014500) Train Loss: 1.8106, Train Steps/Sec: 1.32
156
+ [2026-02-01 17:14:55] (step=0014600) Train Loss: 1.8131, Train Steps/Sec: 1.32
157
+ [2026-02-01 17:16:11] (step=0014700) Train Loss: 1.9024, Train Steps/Sec: 1.32
158
+ [2026-02-01 17:17:27] (step=0014800) Train Loss: 2.2030, Train Steps/Sec: 1.32
159
+ [2026-02-01 17:18:42] (step=0014900) Train Loss: nan, Train Steps/Sec: 1.33
160
+ [2026-02-01 17:19:58] (step=0015000) Train Loss: nan, Train Steps/Sec: 1.33
161
+ [2026-02-01 17:21:13] (step=0015100) Train Loss: nan, Train Steps/Sec: 1.33
162
+ [2026-02-01 17:22:28] (step=0015200) Train Loss: nan, Train Steps/Sec: 1.33
163
+ [2026-02-01 17:23:43] (step=0015300) Train Loss: nan, Train Steps/Sec: 1.33
164
+ [2026-02-01 17:24:58] (step=0015400) Train Loss: nan, Train Steps/Sec: 1.33
165
+ [2026-02-01 17:26:13] (step=0015500) Train Loss: nan, Train Steps/Sec: 1.33
166
+ [2026-02-01 17:27:29] (step=0015600) Train Loss: nan, Train Steps/Sec: 1.33
167
+ [2026-02-01 17:28:44] (step=0015700) Train Loss: nan, Train Steps/Sec: 1.33
168
+ [2026-02-01 17:29:59] (step=0015800) Train Loss: nan, Train Steps/Sec: 1.33
169
+ [2026-02-01 17:31:14] (step=0015900) Train Loss: nan, Train Steps/Sec: 1.33
170
+ [2026-02-01 17:32:29] (step=0016000) Train Loss: nan, Train Steps/Sec: 1.33
171
+ [2026-02-01 17:33:44] (step=0016100) Train Loss: nan, Train Steps/Sec: 1.33
172
+ [2026-02-01 17:34:59] (step=0016200) Train Loss: nan, Train Steps/Sec: 1.33
173
+ [2026-02-01 17:36:14] (step=0016300) Train Loss: nan, Train Steps/Sec: 1.33
174
+ [2026-02-01 17:37:29] (step=0016400) Train Loss: nan, Train Steps/Sec: 1.33
175
+ [2026-02-01 17:38:45] (step=0016500) Train Loss: nan, Train Steps/Sec: 1.33
176
+ [2026-02-01 17:40:00] (step=0016600) Train Loss: nan, Train Steps/Sec: 1.33
177
+ [2026-02-01 17:41:15] (step=0016700) Train Loss: nan, Train Steps/Sec: 1.33
178
+ [2026-02-01 17:42:30] (step=0016800) Train Loss: nan, Train Steps/Sec: 1.33
179
+ [2026-02-01 17:43:45] (step=0016900) Train Loss: nan, Train Steps/Sec: 1.33
180
+ [2026-02-01 17:45:00] (step=0017000) Train Loss: nan, Train Steps/Sec: 1.33
181
+ [2026-02-01 17:46:16] (step=0017100) Train Loss: nan, Train Steps/Sec: 1.33
182
+ [2026-02-01 17:47:31] (step=0017200) Train Loss: nan, Train Steps/Sec: 1.33
183
+ [2026-02-01 17:48:46] (step=0017300) Train Loss: nan, Train Steps/Sec: 1.33
184
+ [2026-02-01 17:50:01] (step=0017400) Train Loss: nan, Train Steps/Sec: 1.33
185
+ [2026-02-01 17:51:16] (step=0017500) Train Loss: nan, Train Steps/Sec: 1.33
186
+ [2026-02-01 17:52:31] (step=0017600) Train Loss: nan, Train Steps/Sec: 1.33
187
+ [2026-02-01 17:53:46] (step=0017700) Train Loss: nan, Train Steps/Sec: 1.33
188
+ [2026-02-01 17:55:01] (step=0017800) Train Loss: nan, Train Steps/Sec: 1.33
189
+ [2026-02-01 17:56:16] (step=0017900) Train Loss: nan, Train Steps/Sec: 1.33
190
+ [2026-02-01 17:57:31] (step=0018000) Train Loss: nan, Train Steps/Sec: 1.33
191
+ [2026-02-01 17:58:46] (step=0018100) Train Loss: nan, Train Steps/Sec: 1.33
192
+ [2026-02-01 18:00:02] (step=0018200) Train Loss: nan, Train Steps/Sec: 1.33
193
+ [2026-02-01 18:01:17] (step=0018300) Train Loss: nan, Train Steps/Sec: 1.33
194
+ [2026-02-01 18:02:32] (step=0018400) Train Loss: nan, Train Steps/Sec: 1.33
195
+ [2026-02-01 18:03:47] (step=0018500) Train Loss: nan, Train Steps/Sec: 1.33
196
+ [2026-02-01 18:05:02] (step=0018600) Train Loss: nan, Train Steps/Sec: 1.33
197
+ [2026-02-01 18:06:17] (step=0018700) Train Loss: nan, Train Steps/Sec: 1.33
198
+ [2026-02-01 18:07:32] (step=0018800) Train Loss: nan, Train Steps/Sec: 1.33
199
+ [2026-02-01 18:08:47] (step=0018900) Train Loss: nan, Train Steps/Sec: 1.33
200
+ [2026-02-01 18:10:02] (step=0019000) Train Loss: nan, Train Steps/Sec: 1.33
201
+ [2026-02-01 18:11:17] (step=0019100) Train Loss: nan, Train Steps/Sec: 1.33
202
+ [2026-02-01 18:12:33] (step=0019200) Train Loss: nan, Train Steps/Sec: 1.33
203
+ [2026-02-01 18:13:48] (step=0019300) Train Loss: nan, Train Steps/Sec: 1.33
204
+ [2026-02-01 18:15:03] (step=0019400) Train Loss: nan, Train Steps/Sec: 1.33
205
+ [2026-02-01 18:16:18] (step=0019500) Train Loss: nan, Train Steps/Sec: 1.33
206
+ [2026-02-01 18:17:33] (step=0019600) Train Loss: nan, Train Steps/Sec: 1.33
207
+ [2026-02-01 18:18:48] (step=0019700) Train Loss: nan, Train Steps/Sec: 1.33
208
+ [2026-02-01 18:20:03] (step=0019800) Train Loss: nan, Train Steps/Sec: 1.33
209
+ [2026-02-01 18:21:18] (step=0019900) Train Loss: nan, Train Steps/Sec: 1.33
210
+ [2026-02-01 18:22:34] (step=0020000) Train Loss: nan, Train Steps/Sec: 1.33
211
+ [2026-02-01 18:22:47] Beginning epoch 2...
212
+ [2026-02-01 18:23:51] (step=0020100) Train Loss: nan, Train Steps/Sec: 1.29
213
+ [2026-02-01 18:25:06] (step=0020200) Train Loss: nan, Train Steps/Sec: 1.33
214
+ [2026-02-01 18:26:21] (step=0020300) Train Loss: nan, Train Steps/Sec: 1.33
215
+ [2026-02-01 18:27:36] (step=0020400) Train Loss: nan, Train Steps/Sec: 1.33
216
+ [2026-02-01 18:28:51] (step=0020500) Train Loss: nan, Train Steps/Sec: 1.33
217
+ [2026-02-01 18:30:06] (step=0020600) Train Loss: nan, Train Steps/Sec: 1.33
218
+ [2026-02-01 18:31:21] (step=0020700) Train Loss: nan, Train Steps/Sec: 1.33
219
+ [2026-02-01 18:32:37] (step=0020800) Train Loss: nan, Train Steps/Sec: 1.33
220
+ [2026-02-01 18:33:52] (step=0020900) Train Loss: nan, Train Steps/Sec: 1.33
221
+ [2026-02-01 18:35:07] (step=0021000) Train Loss: nan, Train Steps/Sec: 1.33
222
+ [2026-02-01 18:36:22] (step=0021100) Train Loss: nan, Train Steps/Sec: 1.33
223
+ [2026-02-01 18:37:37] (step=0021200) Train Loss: nan, Train Steps/Sec: 1.33
224
+ [2026-02-01 18:38:52] (step=0021300) Train Loss: nan, Train Steps/Sec: 1.33
225
+ [2026-02-01 18:40:07] (step=0021400) Train Loss: nan, Train Steps/Sec: 1.33
226
+ [2026-02-01 18:41:22] (step=0021500) Train Loss: nan, Train Steps/Sec: 1.33
227
+ [2026-02-01 18:42:37] (step=0021600) Train Loss: nan, Train Steps/Sec: 1.33
228
+ [2026-02-01 18:43:53] (step=0021700) Train Loss: nan, Train Steps/Sec: 1.33
229
+ [2026-02-01 18:45:08] (step=0021800) Train Loss: nan, Train Steps/Sec: 1.33
230
+ [2026-02-01 18:46:23] (step=0021900) Train Loss: nan, Train Steps/Sec: 1.33
231
+ [2026-02-01 18:47:38] (step=0022000) Train Loss: nan, Train Steps/Sec: 1.33
232
+ [2026-02-01 18:48:53] (step=0022100) Train Loss: nan, Train Steps/Sec: 1.33
233
+ [2026-02-01 18:50:08] (step=0022200) Train Loss: nan, Train Steps/Sec: 1.33
234
+ [2026-02-01 18:51:24] (step=0022300) Train Loss: nan, Train Steps/Sec: 1.33
235
+ [2026-02-01 18:52:39] (step=0022400) Train Loss: nan, Train Steps/Sec: 1.33
236
+ [2026-02-01 18:53:54] (step=0022500) Train Loss: nan, Train Steps/Sec: 1.33
237
+ [2026-02-01 18:55:09] (step=0022600) Train Loss: nan, Train Steps/Sec: 1.33
238
+ [2026-02-01 18:56:24] (step=0022700) Train Loss: nan, Train Steps/Sec: 1.33
239
+ [2026-02-01 18:57:39] (step=0022800) Train Loss: nan, Train Steps/Sec: 1.33
240
+ [2026-02-01 18:58:54] (step=0022900) Train Loss: nan, Train Steps/Sec: 1.33
241
+ [2026-02-01 19:00:09] (step=0023000) Train Loss: nan, Train Steps/Sec: 1.33
242
+ [2026-02-01 19:01:24] (step=0023100) Train Loss: nan, Train Steps/Sec: 1.33
243
+ [2026-02-01 19:02:39] (step=0023200) Train Loss: nan, Train Steps/Sec: 1.33
244
+ [2026-02-01 19:03:54] (step=0023300) Train Loss: nan, Train Steps/Sec: 1.33
245
+ [2026-02-01 19:05:09] (step=0023400) Train Loss: nan, Train Steps/Sec: 1.33
246
+ [2026-02-01 19:06:25] (step=0023500) Train Loss: nan, Train Steps/Sec: 1.33
247
+ [2026-02-01 19:07:40] (step=0023600) Train Loss: nan, Train Steps/Sec: 1.33
248
+ [2026-02-01 19:08:55] (step=0023700) Train Loss: nan, Train Steps/Sec: 1.33
249
+ [2026-02-01 19:10:10] (step=0023800) Train Loss: nan, Train Steps/Sec: 1.33
250
+ [2026-02-01 19:11:25] (step=0023900) Train Loss: nan, Train Steps/Sec: 1.33
251
+ [2026-02-01 19:12:40] (step=0024000) Train Loss: nan, Train Steps/Sec: 1.33
252
+ [2026-02-01 19:13:56] (step=0024100) Train Loss: nan, Train Steps/Sec: 1.33
253
+ [2026-02-01 19:15:11] (step=0024200) Train Loss: nan, Train Steps/Sec: 1.33
254
+ [2026-02-01 19:16:26] (step=0024300) Train Loss: nan, Train Steps/Sec: 1.33
255
+ [2026-02-01 19:17:41] (step=0024400) Train Loss: nan, Train Steps/Sec: 1.33
256
+ [2026-02-01 19:18:56] (step=0024500) Train Loss: nan, Train Steps/Sec: 1.33
257
+ [2026-02-01 19:20:11] (step=0024600) Train Loss: nan, Train Steps/Sec: 1.33
258
+ [2026-02-01 19:21:26] (step=0024700) Train Loss: nan, Train Steps/Sec: 1.33
259
+ [2026-02-01 19:22:41] (step=0024800) Train Loss: nan, Train Steps/Sec: 1.33
260
+ [2026-02-01 19:23:56] (step=0024900) Train Loss: nan, Train Steps/Sec: 1.33
261
+ [2026-02-01 19:25:11] (step=0025000) Train Loss: nan, Train Steps/Sec: 1.33
262
+ 25000
263
+ [2026-02-01 19:25:12] Saved checkpoint to results_256_vp/depth-mu-2-000-SiT-XL-2-VP-velocity-None/checkpoints/0025000.pt
264
+ [2026-02-01 19:26:27] (step=0025100) Train Loss: nan, Train Steps/Sec: 1.32
265
+ [2026-02-01 19:27:36] Generating EMA samples...
266
+ [2026-02-01 19:27:42] (step=0025200) Train Loss: nan, Train Steps/Sec: 1.33
267
+ [2026-02-01 19:28:57] (step=0025300) Train Loss: nan, Train Steps/Sec: 1.33
268
+ [2026-02-01 19:30:13] (step=0025400) Train Loss: nan, Train Steps/Sec: 1.33
269
+ [2026-02-01 19:31:28] (step=0025500) Train Loss: nan, Train Steps/Sec: 1.33
270
+ [2026-02-01 19:32:43] (step=0025600) Train Loss: nan, Train Steps/Sec: 1.33
271
+ [2026-02-01 19:33:58] (step=0025700) Train Loss: nan, Train Steps/Sec: 1.33
272
+ [2026-02-01 19:35:13] (step=0025800) Train Loss: nan, Train Steps/Sec: 1.33
273
+ [2026-02-01 19:36:28] (step=0025900) Train Loss: nan, Train Steps/Sec: 1.33
274
+ [2026-02-01 19:37:44] (step=0026000) Train Loss: nan, Train Steps/Sec: 1.33
275
+ [2026-02-01 19:38:59] (step=0026100) Train Loss: nan, Train Steps/Sec: 1.33
276
+ [2026-02-01 19:40:14] (step=0026200) Train Loss: nan, Train Steps/Sec: 1.33
277
+ [2026-02-01 19:41:29] (step=0026300) Train Loss: nan, Train Steps/Sec: 1.33
278
+ [2026-02-01 19:42:44] (step=0026400) Train Loss: nan, Train Steps/Sec: 1.33
279
+ [2026-02-01 19:43:59] (step=0026500) Train Loss: nan, Train Steps/Sec: 1.33
280
+ [2026-02-01 19:45:14] (step=0026600) Train Loss: nan, Train Steps/Sec: 1.33
281
+ [2026-02-01 19:46:29] (step=0026700) Train Loss: nan, Train Steps/Sec: 1.33
282
+ [2026-02-01 19:47:44] (step=0026800) Train Loss: nan, Train Steps/Sec: 1.33
283
+ [2026-02-01 19:49:00] (step=0026900) Train Loss: nan, Train Steps/Sec: 1.33
284
+ [2026-02-01 19:50:15] (step=0027000) Train Loss: nan, Train Steps/Sec: 1.33
285
+ [2026-02-01 19:51:30] (step=0027100) Train Loss: nan, Train Steps/Sec: 1.33
286
+ [2026-02-01 19:52:45] (step=0027200) Train Loss: nan, Train Steps/Sec: 1.33
287
+ [2026-02-01 19:54:00] (step=0027300) Train Loss: nan, Train Steps/Sec: 1.33
288
+ [2026-02-01 19:55:15] (step=0027400) Train Loss: nan, Train Steps/Sec: 1.33
289
+ [2026-02-01 19:56:30] (step=0027500) Train Loss: nan, Train Steps/Sec: 1.33
290
+ [2026-02-01 19:57:45] (step=0027600) Train Loss: nan, Train Steps/Sec: 1.33
291
+ [2026-02-01 19:59:00] (step=0027700) Train Loss: nan, Train Steps/Sec: 1.33
292
+ [2026-02-01 20:00:15] (step=0027800) Train Loss: nan, Train Steps/Sec: 1.33
293
+ [2026-02-01 20:01:31] (step=0027900) Train Loss: nan, Train Steps/Sec: 1.33
294
+ [2026-02-01 20:02:46] (step=0028000) Train Loss: nan, Train Steps/Sec: 1.33
295
+ [2026-02-01 20:04:01] (step=0028100) Train Loss: nan, Train Steps/Sec: 1.33
296
+ [2026-02-01 20:05:16] (step=0028200) Train Loss: nan, Train Steps/Sec: 1.33
297
+ [2026-02-01 20:06:31] (step=0028300) Train Loss: nan, Train Steps/Sec: 1.33
298
+ [2026-02-01 20:07:46] (step=0028400) Train Loss: nan, Train Steps/Sec: 1.33
299
+ [2026-02-01 20:09:02] (step=0028500) Train Loss: nan, Train Steps/Sec: 1.33
300
+ [2026-02-01 20:10:17] (step=0028600) Train Loss: nan, Train Steps/Sec: 1.33
301
+ [2026-02-01 20:11:32] (step=0028700) Train Loss: nan, Train Steps/Sec: 1.33
302
+ [2026-02-01 20:12:47] (step=0028800) Train Loss: nan, Train Steps/Sec: 1.33
303
+ [2026-02-01 20:14:02] (step=0028900) Train Loss: nan, Train Steps/Sec: 1.33
304
+ [2026-02-01 20:15:17] (step=0029000) Train Loss: nan, Train Steps/Sec: 1.33
305
+ [2026-02-01 20:16:32] (step=0029100) Train Loss: nan, Train Steps/Sec: 1.33
306
+ [2026-02-01 20:17:47] (step=0029200) Train Loss: nan, Train Steps/Sec: 1.33
307
+ [2026-02-01 20:19:02] (step=0029300) Train Loss: nan, Train Steps/Sec: 1.33
308
+ [2026-02-01 20:20:18] (step=0029400) Train Loss: nan, Train Steps/Sec: 1.33
309
+ [2026-02-01 20:21:33] (step=0029500) Train Loss: nan, Train Steps/Sec: 1.33
310
+ [2026-02-01 20:22:48] (step=0029600) Train Loss: nan, Train Steps/Sec: 1.33
311
+ [2026-02-01 20:24:03] (step=0029700) Train Loss: nan, Train Steps/Sec: 1.33
312
+ [2026-02-01 20:25:18] (step=0029800) Train Loss: nan, Train Steps/Sec: 1.33
313
+ [2026-02-01 20:26:33] (step=0029900) Train Loss: nan, Train Steps/Sec: 1.33
314
+ [2026-02-01 20:27:49] (step=0030000) Train Loss: nan, Train Steps/Sec: 1.33
315
+ [2026-02-01 20:28:09] Beginning epoch 3...
316
+ [2026-02-01 20:29:06] (step=0030100) Train Loss: nan, Train Steps/Sec: 1.29
317
+ [2026-02-01 20:30:21] (step=0030200) Train Loss: nan, Train Steps/Sec: 1.33
318
+ [2026-02-01 20:31:36] (step=0030300) Train Loss: nan, Train Steps/Sec: 1.33
319
+ [2026-02-01 20:32:51] (step=0030400) Train Loss: nan, Train Steps/Sec: 1.33
320
+ [2026-02-01 20:34:06] (step=0030500) Train Loss: nan, Train Steps/Sec: 1.33
321
+ [2026-02-01 20:35:22] (step=0030600) Train Loss: nan, Train Steps/Sec: 1.33
322
+ [2026-02-01 20:36:37] (step=0030700) Train Loss: nan, Train Steps/Sec: 1.33
323
+ [2026-02-01 20:37:52] (step=0030800) Train Loss: nan, Train Steps/Sec: 1.33
324
+ [2026-02-01 20:39:07] (step=0030900) Train Loss: nan, Train Steps/Sec: 1.33
325
+ [2026-02-01 20:40:22] (step=0031000) Train Loss: nan, Train Steps/Sec: 1.33
326
+ [2026-02-01 20:41:37] (step=0031100) Train Loss: nan, Train Steps/Sec: 1.33
327
+ [2026-02-01 20:42:52] (step=0031200) Train Loss: nan, Train Steps/Sec: 1.33
328
+ [2026-02-01 20:44:08] (step=0031300) Train Loss: nan, Train Steps/Sec: 1.33
329
+ [2026-02-01 20:45:23] (step=0031400) Train Loss: nan, Train Steps/Sec: 1.33
330
+ [2026-02-01 20:46:38] (step=0031500) Train Loss: nan, Train Steps/Sec: 1.33
331
+ [2026-02-01 20:47:53] (step=0031600) Train Loss: nan, Train Steps/Sec: 1.33
332
+ [2026-02-01 20:49:08] (step=0031700) Train Loss: nan, Train Steps/Sec: 1.33
333
+ [2026-02-01 20:50:23] (step=0031800) Train Loss: nan, Train Steps/Sec: 1.33
334
+ [2026-02-01 20:51:38] (step=0031900) Train Loss: nan, Train Steps/Sec: 1.33
335
+ [2026-02-01 20:52:54] (step=0032000) Train Loss: nan, Train Steps/Sec: 1.33
336
+ [2026-02-01 20:54:09] (step=0032100) Train Loss: nan, Train Steps/Sec: 1.33
337
+ [2026-02-01 20:55:24] (step=0032200) Train Loss: nan, Train Steps/Sec: 1.33
338
+ [2026-02-01 20:56:39] (step=0032300) Train Loss: nan, Train Steps/Sec: 1.33
339
+ [2026-02-01 20:57:54] (step=0032400) Train Loss: nan, Train Steps/Sec: 1.33
340
+ [2026-02-01 20:59:09] (step=0032500) Train Loss: nan, Train Steps/Sec: 1.33
341
+ [2026-02-01 21:00:24] (step=0032600) Train Loss: nan, Train Steps/Sec: 1.33
342
+ [2026-02-01 21:01:39] (step=0032700) Train Loss: nan, Train Steps/Sec: 1.33
343
+ [2026-02-01 21:02:54] (step=0032800) Train Loss: nan, Train Steps/Sec: 1.33
344
+ [2026-02-01 21:04:09] (step=0032900) Train Loss: nan, Train Steps/Sec: 1.33
345
+ [2026-02-01 21:05:25] (step=0033000) Train Loss: nan, Train Steps/Sec: 1.33
346
+ [2026-02-01 21:06:40] (step=0033100) Train Loss: nan, Train Steps/Sec: 1.33
347
+ [2026-02-01 21:07:55] (step=0033200) Train Loss: nan, Train Steps/Sec: 1.33
348
+ [2026-02-01 21:09:10] (step=0033300) Train Loss: nan, Train Steps/Sec: 1.33
349
+ [2026-02-01 21:10:25] (step=0033400) Train Loss: nan, Train Steps/Sec: 1.33
350
+ [2026-02-01 21:11:40] (step=0033500) Train Loss: nan, Train Steps/Sec: 1.33
351
+ [2026-02-01 21:12:56] (step=0033600) Train Loss: nan, Train Steps/Sec: 1.33
352
+ [2026-02-01 21:14:11] (step=0033700) Train Loss: nan, Train Steps/Sec: 1.33
353
+ [2026-02-01 21:15:26] (step=0033800) Train Loss: nan, Train Steps/Sec: 1.33
354
+ [2026-02-01 21:16:41] (step=0033900) Train Loss: nan, Train Steps/Sec: 1.33
355
+ [2026-02-01 21:17:56] (step=0034000) Train Loss: nan, Train Steps/Sec: 1.33
356
+ [2026-02-01 21:19:11] (step=0034100) Train Loss: nan, Train Steps/Sec: 1.33
357
+ [2026-02-01 21:20:27] (step=0034200) Train Loss: nan, Train Steps/Sec: 1.33
358
+ [2026-02-01 21:21:42] (step=0034300) Train Loss: nan, Train Steps/Sec: 1.33
359
+ [2026-02-01 21:22:57] (step=0034400) Train Loss: nan, Train Steps/Sec: 1.33
360
+ [2026-02-01 21:24:12] (step=0034500) Train Loss: nan, Train Steps/Sec: 1.33
361
+ [2026-02-01 21:25:27] (step=0034600) Train Loss: nan, Train Steps/Sec: 1.33
362
+ [2026-02-01 21:26:42] (step=0034700) Train Loss: nan, Train Steps/Sec: 1.33
363
+ [2026-02-01 21:27:57] (step=0034800) Train Loss: nan, Train Steps/Sec: 1.33
364
+ [2026-02-01 21:29:12] (step=0034900) Train Loss: nan, Train Steps/Sec: 1.33
365
+ [2026-02-01 21:30:28] (step=0035000) Train Loss: nan, Train Steps/Sec: 1.33
366
+ [2026-02-01 21:31:43] (step=0035100) Train Loss: nan, Train Steps/Sec: 1.33
367
+ [2026-02-01 21:32:58] (step=0035200) Train Loss: nan, Train Steps/Sec: 1.33
368
+ [2026-02-01 21:34:13] (step=0035300) Train Loss: nan, Train Steps/Sec: 1.33
369
+ [2026-02-01 21:35:28] (step=0035400) Train Loss: nan, Train Steps/Sec: 1.33
370
+ [2026-02-01 21:36:43] (step=0035500) Train Loss: nan, Train Steps/Sec: 1.33
371
+ [2026-02-01 21:37:58] (step=0035600) Train Loss: nan, Train Steps/Sec: 1.33
372
+ [2026-02-01 21:39:13] (step=0035700) Train Loss: nan, Train Steps/Sec: 1.33
373
+ [2026-02-01 21:40:29] (step=0035800) Train Loss: nan, Train Steps/Sec: 1.33
374
+ [2026-02-01 21:41:44] (step=0035900) Train Loss: nan, Train Steps/Sec: 1.33
375
+ [2026-02-01 21:42:59] (step=0036000) Train Loss: nan, Train Steps/Sec: 1.33
376
+ [2026-02-01 21:44:14] (step=0036100) Train Loss: nan, Train Steps/Sec: 1.33
377
+ [2026-02-01 21:45:29] (step=0036200) Train Loss: nan, Train Steps/Sec: 1.33
378
+ [2026-02-01 21:46:44] (step=0036300) Train Loss: nan, Train Steps/Sec: 1.33
379
+ [2026-02-01 21:48:00] (step=0036400) Train Loss: nan, Train Steps/Sec: 1.33
380
+ [2026-02-01 21:49:15] (step=0036500) Train Loss: nan, Train Steps/Sec: 1.33
381
+ [2026-02-01 21:50:30] (step=0036600) Train Loss: nan, Train Steps/Sec: 1.33
382
+ [2026-02-01 21:51:45] (step=0036700) Train Loss: nan, Train Steps/Sec: 1.33
383
+ [2026-02-01 21:53:00] (step=0036800) Train Loss: nan, Train Steps/Sec: 1.33
384
+ [2026-02-01 21:54:15] (step=0036900) Train Loss: nan, Train Steps/Sec: 1.33
385
+ [2026-02-01 21:55:31] (step=0037000) Train Loss: nan, Train Steps/Sec: 1.33
386
+ [2026-02-01 21:56:46] (step=0037100) Train Loss: nan, Train Steps/Sec: 1.33
387
+ [2026-02-01 21:58:01] (step=0037200) Train Loss: nan, Train Steps/Sec: 1.33
388
+ [2026-02-01 21:59:16] (step=0037300) Train Loss: nan, Train Steps/Sec: 1.33
389
+ [2026-02-01 22:00:31] (step=0037400) Train Loss: nan, Train Steps/Sec: 1.33
390
+ [2026-02-01 22:01:46] (step=0037500) Train Loss: nan, Train Steps/Sec: 1.33
391
+ [2026-02-01 22:03:01] (step=0037600) Train Loss: nan, Train Steps/Sec: 1.33
392
+ [2026-02-01 22:04:17] (step=0037700) Train Loss: nan, Train Steps/Sec: 1.33
393
+ [2026-02-01 22:05:32] (step=0037800) Train Loss: nan, Train Steps/Sec: 1.33
394
+ [2026-02-01 22:06:47] (step=0037900) Train Loss: nan, Train Steps/Sec: 1.33
395
+ [2026-02-01 22:08:02] (step=0038000) Train Loss: nan, Train Steps/Sec: 1.33
396
+ [2026-02-01 22:09:17] (step=0038100) Train Loss: nan, Train Steps/Sec: 1.33
397
+ [2026-02-01 22:10:32] (step=0038200) Train Loss: nan, Train Steps/Sec: 1.33
398
+ [2026-02-01 22:11:47] (step=0038300) Train Loss: nan, Train Steps/Sec: 1.33
399
+ [2026-02-01 22:13:02] (step=0038400) Train Loss: nan, Train Steps/Sec: 1.33
400
+ [2026-02-01 22:14:18] (step=0038500) Train Loss: nan, Train Steps/Sec: 1.33
401
+ [2026-02-01 22:15:33] (step=0038600) Train Loss: nan, Train Steps/Sec: 1.33
402
+ [2026-02-01 22:16:48] (step=0038700) Train Loss: nan, Train Steps/Sec: 1.33
403
+ [2026-02-01 22:18:03] (step=0038800) Train Loss: nan, Train Steps/Sec: 1.33
404
+ [2026-02-01 22:19:18] (step=0038900) Train Loss: nan, Train Steps/Sec: 1.33
405
+ [2026-02-01 22:20:33] (step=0039000) Train Loss: nan, Train Steps/Sec: 1.33
406
+ [2026-02-01 22:21:48] (step=0039100) Train Loss: nan, Train Steps/Sec: 1.33
407
+ [2026-02-01 22:23:04] (step=0039200) Train Loss: nan, Train Steps/Sec: 1.33
408
+ [2026-02-01 22:24:19] (step=0039300) Train Loss: nan, Train Steps/Sec: 1.33
409
+ [2026-02-01 22:25:34] (step=0039400) Train Loss: nan, Train Steps/Sec: 1.33
410
+ [2026-02-01 22:26:49] (step=0039500) Train Loss: nan, Train Steps/Sec: 1.33
411
+ [2026-02-01 22:28:04] (step=0039600) Train Loss: nan, Train Steps/Sec: 1.33
412
+ [2026-02-01 22:29:19] (step=0039700) Train Loss: nan, Train Steps/Sec: 1.33
413
+ [2026-02-01 22:30:34] (step=0039800) Train Loss: nan, Train Steps/Sec: 1.33
414
+ [2026-02-01 22:31:49] (step=0039900) Train Loss: nan, Train Steps/Sec: 1.33
415
+ [2026-02-01 22:33:04] (step=0040000) Train Loss: nan, Train Steps/Sec: 1.33
416
+ [2026-02-01 22:33:32] Beginning epoch 4...
417
+ [2026-02-01 22:34:22] (step=0040100) Train Loss: nan, Train Steps/Sec: 1.29
418
+ [2026-02-01 22:35:37] (step=0040200) Train Loss: nan, Train Steps/Sec: 1.33
419
+ [2026-02-01 22:36:52] (step=0040300) Train Loss: nan, Train Steps/Sec: 1.33
420
+ [2026-02-01 22:38:07] (step=0040400) Train Loss: nan, Train Steps/Sec: 1.33
421
+ [2026-02-01 22:39:22] (step=0040500) Train Loss: nan, Train Steps/Sec: 1.33
422
+ [2026-02-01 22:40:37] (step=0040600) Train Loss: nan, Train Steps/Sec: 1.33
423
+ [2026-02-01 22:41:52] (step=0040700) Train Loss: nan, Train Steps/Sec: 1.33
424
+ [2026-02-01 22:43:08] (step=0040800) Train Loss: nan, Train Steps/Sec: 1.33
425
+ [2026-02-01 22:44:23] (step=0040900) Train Loss: nan, Train Steps/Sec: 1.33
426
+ [2026-02-01 22:45:38] (step=0041000) Train Loss: nan, Train Steps/Sec: 1.33
427
+ [2026-02-01 22:46:53] (step=0041100) Train Loss: nan, Train Steps/Sec: 1.33
428
+ [2026-02-01 22:48:08] (step=0041200) Train Loss: nan, Train Steps/Sec: 1.33
429
+ [2026-02-01 22:49:23] (step=0041300) Train Loss: nan, Train Steps/Sec: 1.33
430
+ [2026-02-01 22:50:39] (step=0041400) Train Loss: nan, Train Steps/Sec: 1.33
431
+ [2026-02-01 22:51:54] (step=0041500) Train Loss: nan, Train Steps/Sec: 1.33
432
+ [2026-02-01 22:53:09] (step=0041600) Train Loss: nan, Train Steps/Sec: 1.33
433
+ [2026-02-01 22:54:24] (step=0041700) Train Loss: nan, Train Steps/Sec: 1.33
434
+ [2026-02-01 22:55:39] (step=0041800) Train Loss: nan, Train Steps/Sec: 1.33
435
+ [2026-02-01 22:56:54] (step=0041900) Train Loss: nan, Train Steps/Sec: 1.33
436
+ [2026-02-01 22:58:09] (step=0042000) Train Loss: nan, Train Steps/Sec: 1.33
437
+ [2026-02-01 22:59:24] (step=0042100) Train Loss: nan, Train Steps/Sec: 1.33
438
+ [2026-02-01 23:00:40] (step=0042200) Train Loss: nan, Train Steps/Sec: 1.33
439
+ [2026-02-01 23:01:55] (step=0042300) Train Loss: nan, Train Steps/Sec: 1.33
440
+ [2026-02-01 23:03:10] (step=0042400) Train Loss: nan, Train Steps/Sec: 1.33
441
+ [2026-02-01 23:04:25] (step=0042500) Train Loss: nan, Train Steps/Sec: 1.33
442
+ [2026-02-01 23:05:40] (step=0042600) Train Loss: nan, Train Steps/Sec: 1.33
443
+ [2026-02-01 23:06:55] (step=0042700) Train Loss: nan, Train Steps/Sec: 1.33
444
+ [2026-02-01 23:08:10] (step=0042800) Train Loss: nan, Train Steps/Sec: 1.33
445
+ [2026-02-01 23:09:26] (step=0042900) Train Loss: nan, Train Steps/Sec: 1.33
446
+ [2026-02-01 23:10:41] (step=0043000) Train Loss: nan, Train Steps/Sec: 1.33
447
+ [2026-02-01 23:11:56] (step=0043100) Train Loss: nan, Train Steps/Sec: 1.33
448
+ [2026-02-01 23:13:11] (step=0043200) Train Loss: nan, Train Steps/Sec: 1.33
449
+ [2026-02-01 23:14:26] (step=0043300) Train Loss: nan, Train Steps/Sec: 1.33
450
+ [2026-02-01 23:15:41] (step=0043400) Train Loss: nan, Train Steps/Sec: 1.33
451
+ [2026-02-01 23:16:56] (step=0043500) Train Loss: nan, Train Steps/Sec: 1.33
452
+ [2026-02-01 23:18:11] (step=0043600) Train Loss: nan, Train Steps/Sec: 1.33
453
+ [2026-02-01 23:19:27] (step=0043700) Train Loss: nan, Train Steps/Sec: 1.33
454
+ [2026-02-01 23:20:42] (step=0043800) Train Loss: nan, Train Steps/Sec: 1.33
455
+ [2026-02-01 23:21:57] (step=0043900) Train Loss: nan, Train Steps/Sec: 1.33
456
+ [2026-02-01 23:23:12] (step=0044000) Train Loss: nan, Train Steps/Sec: 1.33
457
+ [2026-02-01 23:24:27] (step=0044100) Train Loss: nan, Train Steps/Sec: 1.33
458
+ [2026-02-01 23:25:42] (step=0044200) Train Loss: nan, Train Steps/Sec: 1.33
459
+ [2026-02-01 23:26:57] (step=0044300) Train Loss: nan, Train Steps/Sec: 1.33
460
+ [2026-02-01 23:28:12] (step=0044400) Train Loss: nan, Train Steps/Sec: 1.33
461
+ [2026-02-01 23:29:28] (step=0044500) Train Loss: nan, Train Steps/Sec: 1.33
462
+ [2026-02-01 23:30:43] (step=0044600) Train Loss: nan, Train Steps/Sec: 1.33
463
+ [2026-02-01 23:31:58] (step=0044700) Train Loss: nan, Train Steps/Sec: 1.33
464
+ [2026-02-01 23:33:13] (step=0044800) Train Loss: nan, Train Steps/Sec: 1.33
465
+ [2026-02-01 23:34:28] (step=0044900) Train Loss: nan, Train Steps/Sec: 1.33
466
+ [2026-02-01 23:35:43] (step=0045000) Train Loss: nan, Train Steps/Sec: 1.33
467
+ [2026-02-01 23:36:58] (step=0045100) Train Loss: nan, Train Steps/Sec: 1.33
468
+ [2026-02-01 23:38:13] (step=0045200) Train Loss: nan, Train Steps/Sec: 1.33
469
+ [2026-02-01 23:39:28] (step=0045300) Train Loss: nan, Train Steps/Sec: 1.33
470
+ [2026-02-01 23:40:44] (step=0045400) Train Loss: nan, Train Steps/Sec: 1.33
471
+ [2026-02-01 23:41:59] (step=0045500) Train Loss: nan, Train Steps/Sec: 1.33
472
+ [2026-02-01 23:43:14] (step=0045600) Train Loss: nan, Train Steps/Sec: 1.33
473
+ [2026-02-01 23:44:29] (step=0045700) Train Loss: nan, Train Steps/Sec: 1.33
474
+ [2026-02-01 23:45:44] (step=0045800) Train Loss: nan, Train Steps/Sec: 1.33
475
+ [2026-02-01 23:46:59] (step=0045900) Train Loss: nan, Train Steps/Sec: 1.33
476
+ [2026-02-01 23:48:14] (step=0046000) Train Loss: nan, Train Steps/Sec: 1.33
477
+ [2026-02-01 23:49:29] (step=0046100) Train Loss: nan, Train Steps/Sec: 1.33
478
+ [2026-02-01 23:50:45] (step=0046200) Train Loss: nan, Train Steps/Sec: 1.33
479
+ [2026-02-01 23:52:00] (step=0046300) Train Loss: nan, Train Steps/Sec: 1.33
480
+ [2026-02-01 23:53:15] (step=0046400) Train Loss: nan, Train Steps/Sec: 1.33
481
+ [2026-02-01 23:54:30] (step=0046500) Train Loss: nan, Train Steps/Sec: 1.33
482
+ [2026-02-01 23:55:45] (step=0046600) Train Loss: nan, Train Steps/Sec: 1.33
483
+ [2026-02-01 23:57:00] (step=0046700) Train Loss: nan, Train Steps/Sec: 1.33
484
+ [2026-02-01 23:58:15] (step=0046800) Train Loss: nan, Train Steps/Sec: 1.33
485
+ [2026-02-01 23:59:31] (step=0046900) Train Loss: nan, Train Steps/Sec: 1.33
486
+ [2026-02-02 00:00:46] (step=0047000) Train Loss: nan, Train Steps/Sec: 1.33
487
+ [2026-02-02 00:02:01] (step=0047100) Train Loss: nan, Train Steps/Sec: 1.33
488
+ [2026-02-02 00:03:16] (step=0047200) Train Loss: nan, Train Steps/Sec: 1.33
489
+ [2026-02-02 00:04:31] (step=0047300) Train Loss: nan, Train Steps/Sec: 1.33
490
+ [2026-02-02 00:05:46] (step=0047400) Train Loss: nan, Train Steps/Sec: 1.33
491
+ [2026-02-02 00:07:01] (step=0047500) Train Loss: nan, Train Steps/Sec: 1.33
492
+ [2026-02-02 00:08:16] (step=0047600) Train Loss: nan, Train Steps/Sec: 1.33
493
+ [2026-02-02 00:09:32] (step=0047700) Train Loss: nan, Train Steps/Sec: 1.33
494
+ [2026-02-02 00:10:47] (step=0047800) Train Loss: nan, Train Steps/Sec: 1.33
495
+ [2026-02-02 00:12:02] (step=0047900) Train Loss: nan, Train Steps/Sec: 1.33
496
+ [2026-02-02 00:13:17] (step=0048000) Train Loss: nan, Train Steps/Sec: 1.33
497
+ [2026-02-02 00:14:32] (step=0048100) Train Loss: nan, Train Steps/Sec: 1.33
498
+ [2026-02-02 00:15:47] (step=0048200) Train Loss: nan, Train Steps/Sec: 1.33
499
+ [2026-02-02 00:17:02] (step=0048300) Train Loss: nan, Train Steps/Sec: 1.33
500
+ [2026-02-02 00:18:17] (step=0048400) Train Loss: nan, Train Steps/Sec: 1.33
501
+ [2026-02-02 00:19:32] (step=0048500) Train Loss: nan, Train Steps/Sec: 1.33
502
+ [2026-02-02 00:20:48] (step=0048600) Train Loss: nan, Train Steps/Sec: 1.33
503
+ [2026-02-02 00:22:03] (step=0048700) Train Loss: nan, Train Steps/Sec: 1.33
504
+ [2026-02-02 00:23:18] (step=0048800) Train Loss: nan, Train Steps/Sec: 1.33
505
+ [2026-02-02 00:24:33] (step=0048900) Train Loss: nan, Train Steps/Sec: 1.33
506
+ [2026-02-02 00:25:48] (step=0049000) Train Loss: nan, Train Steps/Sec: 1.33
507
+ [2026-02-02 00:27:03] (step=0049100) Train Loss: nan, Train Steps/Sec: 1.33
508
+ [2026-02-02 00:28:18] (step=0049200) Train Loss: nan, Train Steps/Sec: 1.33
509
+ [2026-02-02 00:29:33] (step=0049300) Train Loss: nan, Train Steps/Sec: 1.33
510
+ [2026-02-02 00:30:48] (step=0049400) Train Loss: nan, Train Steps/Sec: 1.33
511
+ [2026-02-02 00:32:04] (step=0049500) Train Loss: nan, Train Steps/Sec: 1.33
512
+ [2026-02-02 00:33:19] (step=0049600) Train Loss: nan, Train Steps/Sec: 1.33
513
+ [2026-02-02 00:34:34] (step=0049700) Train Loss: nan, Train Steps/Sec: 1.33
514
+ [2026-02-02 00:35:49] (step=0049800) Train Loss: nan, Train Steps/Sec: 1.33
515
+ [2026-02-02 00:37:04] (step=0049900) Train Loss: nan, Train Steps/Sec: 1.33
516
+ [2026-02-02 00:38:19] (step=0050000) Train Loss: nan, Train Steps/Sec: 1.33
517
+ 50000
518
+ [2026-02-02 00:38:20] Saved checkpoint to results_256_vp/depth-mu-2-000-SiT-XL-2-VP-velocity-None/checkpoints/0050000.pt
519
+ [2026-02-02 00:38:54] Beginning epoch 5...
520
+ [2026-02-02 00:39:37] (step=0050100) Train Loss: nan, Train Steps/Sec: 1.28
521
+ [2026-02-02 00:40:53] (step=0050200) Train Loss: nan, Train Steps/Sec: 1.33
522
+ [2026-02-02 00:42:08] (step=0050300) Train Loss: nan, Train Steps/Sec: 1.33
523
+ [2026-02-02 00:43:11] Generating EMA samples...
524
+ [2026-02-02 00:43:23] (step=0050400) Train Loss: nan, Train Steps/Sec: 1.33
525
+ [2026-02-02 00:44:38] (step=0050500) Train Loss: nan, Train Steps/Sec: 1.33
526
+ [2026-02-02 00:45:53] (step=0050600) Train Loss: nan, Train Steps/Sec: 1.33
527
+ [2026-02-02 00:47:08] (step=0050700) Train Loss: nan, Train Steps/Sec: 1.33
528
+ [2026-02-02 00:48:23] (step=0050800) Train Loss: nan, Train Steps/Sec: 1.33
529
+ [2026-02-02 00:49:38] (step=0050900) Train Loss: nan, Train Steps/Sec: 1.33
530
+ [2026-02-02 00:50:53] (step=0051000) Train Loss: nan, Train Steps/Sec: 1.33
531
+ [2026-02-02 00:52:09] (step=0051100) Train Loss: nan, Train Steps/Sec: 1.33
532
+ [2026-02-02 00:53:24] (step=0051200) Train Loss: nan, Train Steps/Sec: 1.33
533
+ [2026-02-02 00:54:39] (step=0051300) Train Loss: nan, Train Steps/Sec: 1.33
534
+ [2026-02-02 00:55:54] (step=0051400) Train Loss: nan, Train Steps/Sec: 1.33
535
+ [2026-02-02 00:57:09] (step=0051500) Train Loss: nan, Train Steps/Sec: 1.33
536
+ [2026-02-02 00:58:24] (step=0051600) Train Loss: nan, Train Steps/Sec: 1.33
537
+ [2026-02-02 00:59:39] (step=0051700) Train Loss: nan, Train Steps/Sec: 1.33
538
+ [2026-02-02 01:00:54] (step=0051800) Train Loss: nan, Train Steps/Sec: 1.33
539
+ [2026-02-02 01:02:10] (step=0051900) Train Loss: nan, Train Steps/Sec: 1.33
540
+ [2026-02-02 01:03:25] (step=0052000) Train Loss: nan, Train Steps/Sec: 1.33
541
+ [2026-02-02 01:04:40] (step=0052100) Train Loss: nan, Train Steps/Sec: 1.33
542
+ [2026-02-02 01:05:55] (step=0052200) Train Loss: nan, Train Steps/Sec: 1.33
543
+ [2026-02-02 01:07:10] (step=0052300) Train Loss: nan, Train Steps/Sec: 1.33
544
+ [2026-02-02 01:08:25] (step=0052400) Train Loss: nan, Train Steps/Sec: 1.33
545
+ [2026-02-02 01:09:41] (step=0052500) Train Loss: nan, Train Steps/Sec: 1.33
546
+ [2026-02-02 01:10:56] (step=0052600) Train Loss: nan, Train Steps/Sec: 1.33
547
+ [2026-02-02 01:12:11] (step=0052700) Train Loss: nan, Train Steps/Sec: 1.33
548
+ [2026-02-02 01:13:26] (step=0052800) Train Loss: nan, Train Steps/Sec: 1.33
549
+ [2026-02-02 01:14:41] (step=0052900) Train Loss: nan, Train Steps/Sec: 1.33
550
+ [2026-02-02 01:15:57] (step=0053000) Train Loss: nan, Train Steps/Sec: 1.33
551
+ [2026-02-02 01:17:12] (step=0053100) Train Loss: nan, Train Steps/Sec: 1.33
552
+ [2026-02-02 01:18:27] (step=0053200) Train Loss: nan, Train Steps/Sec: 1.33
553
+ [2026-02-02 01:19:42] (step=0053300) Train Loss: nan, Train Steps/Sec: 1.33
554
+ [2026-02-02 01:20:57] (step=0053400) Train Loss: nan, Train Steps/Sec: 1.33
555
+ [2026-02-02 01:22:13] (step=0053500) Train Loss: nan, Train Steps/Sec: 1.33
556
+ [2026-02-02 01:23:28] (step=0053600) Train Loss: nan, Train Steps/Sec: 1.33
557
+ [2026-02-02 01:24:43] (step=0053700) Train Loss: nan, Train Steps/Sec: 1.33
558
+ [2026-02-02 01:25:58] (step=0053800) Train Loss: nan, Train Steps/Sec: 1.33
559
+ [2026-02-02 01:27:13] (step=0053900) Train Loss: nan, Train Steps/Sec: 1.33
560
+ [2026-02-02 01:28:28] (step=0054000) Train Loss: nan, Train Steps/Sec: 1.33
561
+ [2026-02-02 01:29:44] (step=0054100) Train Loss: nan, Train Steps/Sec: 1.33
562
+ [2026-02-02 01:30:59] (step=0054200) Train Loss: nan, Train Steps/Sec: 1.33
563
+ [2026-02-02 01:32:14] (step=0054300) Train Loss: nan, Train Steps/Sec: 1.33
564
+ [2026-02-02 01:33:29] (step=0054400) Train Loss: nan, Train Steps/Sec: 1.33
565
+ [2026-02-02 01:34:44] (step=0054500) Train Loss: nan, Train Steps/Sec: 1.33
566
+ [2026-02-02 01:35:59] (step=0054600) Train Loss: nan, Train Steps/Sec: 1.33
567
+ [2026-02-02 01:37:15] (step=0054700) Train Loss: nan, Train Steps/Sec: 1.33
568
+ [2026-02-02 01:38:30] (step=0054800) Train Loss: nan, Train Steps/Sec: 1.33
569
+ [2026-02-02 01:39:45] (step=0054900) Train Loss: nan, Train Steps/Sec: 1.33
570
+ [2026-02-02 01:41:00] (step=0055000) Train Loss: nan, Train Steps/Sec: 1.33
571
+ [2026-02-02 01:42:15] (step=0055100) Train Loss: nan, Train Steps/Sec: 1.33
572
+ [2026-02-02 01:43:30] (step=0055200) Train Loss: nan, Train Steps/Sec: 1.33
573
+ [2026-02-02 01:44:46] (step=0055300) Train Loss: nan, Train Steps/Sec: 1.33
574
+ [2026-02-02 01:46:01] (step=0055400) Train Loss: nan, Train Steps/Sec: 1.33
575
+ [2026-02-02 01:47:16] (step=0055500) Train Loss: nan, Train Steps/Sec: 1.33
576
+ [2026-02-02 01:48:31] (step=0055600) Train Loss: nan, Train Steps/Sec: 1.33
577
+ [2026-02-02 01:49:46] (step=0055700) Train Loss: nan, Train Steps/Sec: 1.33
578
+ [2026-02-02 01:51:02] (step=0055800) Train Loss: nan, Train Steps/Sec: 1.33
579
+ [2026-02-02 01:52:17] (step=0055900) Train Loss: nan, Train Steps/Sec: 1.33
580
+ [2026-02-02 01:53:32] (step=0056000) Train Loss: nan, Train Steps/Sec: 1.33
581
+ [2026-02-02 01:54:47] (step=0056100) Train Loss: nan, Train Steps/Sec: 1.33
582
+ [2026-02-02 01:56:02] (step=0056200) Train Loss: nan, Train Steps/Sec: 1.33
583
+ [2026-02-02 01:57:17] (step=0056300) Train Loss: nan, Train Steps/Sec: 1.33
584
+ [2026-02-02 01:58:32] (step=0056400) Train Loss: nan, Train Steps/Sec: 1.33
585
+ [2026-02-02 01:59:48] (step=0056500) Train Loss: nan, Train Steps/Sec: 1.33
586
+ [2026-02-02 02:01:03] (step=0056600) Train Loss: nan, Train Steps/Sec: 1.33
587
+ [2026-02-02 02:02:18] (step=0056700) Train Loss: nan, Train Steps/Sec: 1.33
588
+ [2026-02-02 02:03:33] (step=0056800) Train Loss: nan, Train Steps/Sec: 1.33
589
+ [2026-02-02 02:04:48] (step=0056900) Train Loss: nan, Train Steps/Sec: 1.33
590
+ [2026-02-02 02:06:04] (step=0057000) Train Loss: nan, Train Steps/Sec: 1.33
591
+ [2026-02-02 02:07:19] (step=0057100) Train Loss: nan, Train Steps/Sec: 1.33
592
+ [2026-02-02 02:08:34] (step=0057200) Train Loss: nan, Train Steps/Sec: 1.33
593
+ [2026-02-02 02:09:49] (step=0057300) Train Loss: nan, Train Steps/Sec: 1.33
594
+ [2026-02-02 02:11:04] (step=0057400) Train Loss: nan, Train Steps/Sec: 1.33
595
+ [2026-02-02 02:12:19] (step=0057500) Train Loss: nan, Train Steps/Sec: 1.33
596
+ [2026-02-02 02:13:35] (step=0057600) Train Loss: nan, Train Steps/Sec: 1.33
597
+ [2026-02-02 02:14:50] (step=0057700) Train Loss: nan, Train Steps/Sec: 1.33
598
+ [2026-02-02 02:16:05] (step=0057800) Train Loss: nan, Train Steps/Sec: 1.33
599
+ [2026-02-02 02:17:20] (step=0057900) Train Loss: nan, Train Steps/Sec: 1.33
600
+ [2026-02-02 02:18:35] (step=0058000) Train Loss: nan, Train Steps/Sec: 1.33
601
+ [2026-02-02 02:19:50] (step=0058100) Train Loss: nan, Train Steps/Sec: 1.33
602
+ [2026-02-02 02:21:05] (step=0058200) Train Loss: nan, Train Steps/Sec: 1.33
603
+ [2026-02-02 02:22:21] (step=0058300) Train Loss: nan, Train Steps/Sec: 1.33
604
+ [2026-02-02 02:23:36] (step=0058400) Train Loss: nan, Train Steps/Sec: 1.33
605
+ [2026-02-02 02:24:51] (step=0058500) Train Loss: nan, Train Steps/Sec: 1.33
606
+ [2026-02-02 02:26:06] (step=0058600) Train Loss: nan, Train Steps/Sec: 1.33
607
+ [2026-02-02 02:27:21] (step=0058700) Train Loss: nan, Train Steps/Sec: 1.33
608
+ [2026-02-02 02:28:36] (step=0058800) Train Loss: nan, Train Steps/Sec: 1.33
609
+ [2026-02-02 02:29:51] (step=0058900) Train Loss: nan, Train Steps/Sec: 1.33
610
+ [2026-02-02 02:31:06] (step=0059000) Train Loss: nan, Train Steps/Sec: 1.33
611
+ [2026-02-02 02:32:21] (step=0059100) Train Loss: nan, Train Steps/Sec: 1.33
612
+ [2026-02-02 02:33:36] (step=0059200) Train Loss: nan, Train Steps/Sec: 1.33
613
+ [2026-02-02 02:34:51] (step=0059300) Train Loss: nan, Train Steps/Sec: 1.33
614
+ [2026-02-02 02:36:07] (step=0059400) Train Loss: nan, Train Steps/Sec: 1.33
615
+ [2026-02-02 02:38:31] (step=0059500) Train Loss: nan, Train Steps/Sec: 0.69
616
+ [2026-02-02 02:41:15] (step=0059600) Train Loss: nan, Train Steps/Sec: 0.61
617
+ [2026-02-02 02:43:58] (step=0059700) Train Loss: nan, Train Steps/Sec: 0.61
618
+ [2026-02-02 02:46:41] (step=0059800) Train Loss: nan, Train Steps/Sec: 0.61
619
+ [2026-02-02 02:49:24] (step=0059900) Train Loss: nan, Train Steps/Sec: 0.61
620
+ [2026-02-02 02:52:09] (step=0060000) Train Loss: nan, Train Steps/Sec: 0.61
621
+ [2026-02-02 02:53:37] Beginning epoch 6...
622
+ [2026-02-02 02:54:54] (step=0060100) Train Loss: nan, Train Steps/Sec: 0.61
623
+ [2026-02-02 02:57:37] (step=0060200) Train Loss: nan, Train Steps/Sec: 0.61
624
+ [2026-02-02 03:00:21] (step=0060300) Train Loss: nan, Train Steps/Sec: 0.61
625
+ [2026-02-02 03:03:04] (step=0060400) Train Loss: nan, Train Steps/Sec: 0.61
626
+ [2026-02-02 03:05:46] (step=0060500) Train Loss: nan, Train Steps/Sec: 0.62
627
+ [2026-02-02 03:08:31] (step=0060600) Train Loss: nan, Train Steps/Sec: 0.61
628
+ [2026-02-02 03:11:14] (step=0060700) Train Loss: nan, Train Steps/Sec: 0.61
Rectified_Noise/GVP-Disp/权重类型分析.md ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 损失函数权重类型分析
2
+
3
+ ## 代码位置
4
+ `transport/transport.py` 第150-156行
5
+
6
+ ## 三种权重类型
7
+
8
+ ### 1. WeightType.NONE
9
+ ```python
10
+ weight = 1
11
+ ```
12
+ **特点:**
13
+ - 均匀权重,所有时间步的损失贡献相同
14
+ - 最简单的权重策略
15
+
16
+ **影响:**
17
+ - 对训练过程的影响:所有时间步 t 的损失被同等对待
18
+ - 优点:实现简单,训练稳定
19
+ - 缺点:可能忽略不同时间步的重要性差异
20
+
21
+ ### 2. WeightType.VELOCITY
22
+ ```python
23
+ weight = (drift_var / sigma_t) ** 2
24
+ ```
25
+ **特点:**
26
+ - 权重与 `(drift_var / sigma_t)²` 成正比
27
+ - `drift_var` 是扩散系数(diffusion coefficient)
28
+ - `sigma_t` 是噪声系数(noise coefficient)
29
+
30
+ **数学含义:**
31
+ - 对于线性路径(ICPlan):`sigma_t = 1 - t`,`drift_var` 是扩散项
32
+ - 权重 = `(扩散系数 / 噪声系数)²`
33
+ - 当 `sigma_t` 较小时(接近 t=1,噪声少),权重较大
34
+ - 当 `sigma_t` 较大时(接近 t=0,噪声多),权重较小
35
+
36
+ **影响:**
37
+ - **强调后期时间步**:在去噪过程的后期(t接近1,噪声少)给予更高权重
38
+ - **训练重点**:模型在低噪声区域的预测精度更重要
39
+ - **适用场景**:当最终生成质量(低噪声区域)是关键时
40
+
41
+ **时间依赖行为:**
42
+ - t → 0(高噪声):`sigma_t` 大 → 权重小
43
+ - t → 1(低噪声):`sigma_t` 小 → 权重大
44
+
45
+ ### 3. WeightType.LIKELIHOOD
46
+ ```python
47
+ weight = drift_var / (sigma_t ** 2)
48
+ ```
49
+ **特点:**
50
+ - 权重与 `drift_var / sigma_t²` 成正比
51
+ - 相比 VELOCITY 权重,分母是 `sigma_t²` 而不是 `sigma_t`
52
+
53
+ **数学含义:**
54
+ - 权重 = `扩散系数 / 噪声系数²`
55
+ - 当 `sigma_t` 较小时,`sigma_t²` 更小,权重更大
56
+ - 当 `sigma_t` 较大时,`sigma_t²` 更大,权重更小
57
+
58
+ **影响:**
59
+ - **更强烈地强调后期时间步**:相比 VELOCITY,对低噪声区域的权重更大
60
+ - **训练重点**:极大化模型在低噪声区域的预测精度
61
+ - **适用场景**:当需要最大化似然或生成质量时
62
+
63
+ **与 VELOCITY 的对比:**
64
+ - LIKELIHOOD 权重 = VELOCITY 权重 × `(1 / drift_var)`(由 `drift_var / sigma_t² = (drift_var / sigma_t)² / drift_var` 可得)
65
+ - 在相同 `sigma_t` 下,当 `drift_var ≤ 1` 时,LIKELIHOOD 权重 ≥ VELOCITY 权重;当 `drift_var > 1` 时则相反
66
+ - LIKELIHOOD 对后期时间步的强调更极端
67
+
68
+ ## 权重随时间的典型行为(线性路径示例)
69
+
70
+ 假设线性路径(ICPlan):
71
+ - `sigma_t = 1 - t`(从 1 到 0)
72
+ - `drift_var` 通常与 `t` 相关
73
+
74
+ ### 时间步 t=0.1(高噪声)
75
+ - `sigma_t ≈ 0.9`
76
+ - **NONE**: weight = 1
77
+ - **VELOCITY**: weight = `(drift_var / 0.9)²` ≈ 中等
78
+ - **LIKELIHOOD**: weight = `drift_var / 0.81` ≈ 较大
79
+
80
+ ### 时间步 t=0.5(中等噪声)
81
+ - `sigma_t = 0.5`
82
+ - **NONE**: weight = 1
83
+ - **VELOCITY**: weight = `(drift_var / 0.5)²` = `4 × drift_var²`
84
+ - **LIKELIHOOD**: weight = `drift_var / 0.25` = `4 × drift_var`
85
+
86
+ ### 时间步 t=0.9(低噪声)
87
+ - `sigma_t ≈ 0.1`
88
+ - **NONE**: weight = 1
89
+ - **VELOCITY**: weight = `(drift_var / 0.1)²` = `100 × drift_var²`(很大)
90
+ - **LIKELIHOOD**: weight = `drift_var / 0.01` = `100 × drift_var`(非常大)
91
+
92
+ ## 实际损失计算
93
+
94
+ ### 对于 NOISE 模型类型:
95
+ ```python
96
+ loss = mean_flat(weight * ((model_output - x0) ** 2))
97
+ ```
98
+
99
+ ### 对于 SCORE 模型类型:
100
+ ```python
101
+ loss = mean_flat(weight * ((model_output * sigma_t + x0) ** 2))
102
+ ```
103
+
104
+ ## 选择建议
105
+
106
+ 1. **WeightType.NONE**
107
+ - 适合:简单实验、基线对比
108
+ - 优点:训练稳定,实现简单
109
+ - 缺点:可能忽略时间步重要性
110
+
111
+ 2. **WeightType.VELOCITY**
112
+ - 适合:关注最终生成质量
113
+ - 优点:强调低噪声区域,生成质量通常更好
114
+ - 缺点:可能在高噪声区域训练不足
115
+
116
+ 3. **WeightType.LIKELIHOOD**
117
+ - 适合:需要最大化似然、追求最高生成质量
118
+ - 优点:最强调低噪声区域
119
+ - 缺点:可能在高噪声区域训练严重不足,训练可能不稳定
120
+
121
+ ## 总结
122
+
123
+ 三种权重类型形成了一个从均匀到极端强调后期的梯度:
124
+
125
+ ```
126
+ NONE (均匀) < VELOCITY (强调后期) < LIKELIHOOD (极端强调后期)
127
+ ```
128
+
129
+ 选择哪种权重取决于:
130
+ - 训练目标(生成质量 vs 训练稳定性)
131
+ - 数据特性
132
+ - 模型类型(NOISE vs SCORE)
133
+ - 路径类型(LINEAR vs GVP vs VP)
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000032.png ADDED
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000077.png ADDED
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000133.png ADDED
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000161.png ADDED
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000220.png ADDED
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000331.png ADDED
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000387.png ADDED
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000505.png ADDED
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000517.png ADDED
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000551.png ADDED
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000726.png ADDED
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000817.png ADDED
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000865.png ADDED
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000914.png ADDED
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/000940.png ADDED
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001043.png ADDED
Rectified_Noise/VP-Disp/VP_samples/depth-mu-2-threshold-0.0-0175000-base-cfg-1.0-64-SDE-100-Euler-sigma-Mean-0.04/001210.png ADDED
SiT_back/SiT_clean/W_training.log ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ nohup: ignoring input
2
+ W1124 10:39:29.690000 58030 site-packages/torch/distributed/run.py:793]
3
+ W1124 10:39:29.690000 58030 site-packages/torch/distributed/run.py:793] *****************************************
4
+ W1124 10:39:29.690000 58030 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
5
+ W1124 10:39:29.690000 58030 site-packages/torch/distributed/run.py:793] *****************************************
6
+ [NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
7
+ [NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
8
+ [NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
9
+ [NOTICE] The application is pending for GPU resource in asynchronous queue. The longest waiting time in queue is 1800 seconds.
10
+ Starting rank=0, seed=0, world_size=4.
11
+ [2025-11-24 10:39:48] Experiment directory created at results/005-SiT-XL-2-Linear-velocity-None
12
+ [2025-11-24 10:39:48] Sample images will be saved to results/005-SiT-XL-2-Linear-velocity-None/pic
13
+ Starting rank=2, seed=2, world_size=4.
14
+ Starting rank=1, seed=1, world_size=4.
15
+ Starting rank=3, seed=3, world_size=4.
16
+ [2025-11-24 10:40:02] SiT Parameters: 675,129,632
17
+ [2025-11-24 10:40:04] Dataset contains 1,281,167 images (/gemini/platform/public/hzh/datasets/Imagenet/train/)
18
+ [2025-11-24 10:40:04] Training for 140000 epochs...
19
+ [2025-11-24 10:40:04] Beginning epoch 0...
20
+ [2025-11-24 10:40:24] Saved checkpoint to results/005-SiT-XL-2-Linear-velocity-None/checkpoints/0000010.pt
21
+ [2025-11-24 10:40:24] Generating EMA samples...
22
+ [2025-11-24 10:40:25] Saved sample images grid to results/005-SiT-XL-2-Linear-velocity-None/pic/step_0000010_samples_grid.png
23
+ [2025-11-24 10:40:25] Generating EMA samples done.
24
+ W1124 10:40:39.173000 58030 site-packages/torch/distributed/elastic/agent/server/api.py:704] Received 2 death signal, shutting down workers
25
+ W1124 10:40:39.173000 58030 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 58079 closing signal SIGINT
26
+ W1124 10:40:39.174000 58030 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 58080 closing signal SIGINT
27
+ W1124 10:40:39.174000 58030 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 58081 closing signal SIGINT
28
+ W1124 10:40:39.174000 58030 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 58082 closing signal SIGINT
29
+ [rank0]: Traceback (most recent call last):
30
+ [rank0]: File "/gemini/space/gzy_new/Noise_Matching/SiT_clean/train.py", line 371, in <module>
31
+ [rank0]: main(args)
32
+ [rank0]: File "/gemini/space/gzy_new/Noise_Matching/SiT_clean/train.py", line 298, in main
33
+ [rank0]: torch.save(checkpoint, checkpoint_path)
34
+ [rank0]: File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/serialization.py", line 850, in save
35
+ [rank0]: _save(
36
+ [rank0]: File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/serialization.py", line 1114, in _save
37
+ [rank0]: zip_file.write_record(name, storage, num_bytes)
38
+ [rank0]: KeyboardInterrupt
39
+ W1124 10:40:39.380000 58030 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 58079 closing signal SIGTERM
40
+ W1124 10:40:39.380000 58030 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 58080 closing signal SIGTERM
41
+ W1124 10:40:39.381000 58030 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 58081 closing signal SIGTERM
42
+ W1124 10:40:39.381000 58030 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 58082 closing signal SIGTERM
43
+ Traceback (most recent call last):
44
+ File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/api.py", line 696, in run
45
+ result = self._invoke_run(role)
46
+ ^^^^^^^^^^^^^^^^^^^^^^
47
+ File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/api.py", line 855, in _invoke_run
48
+ time.sleep(monitor_interval)
49
+ File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 84, in _terminate_process_handler
50
+ raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
51
+ torch.distributed.elastic.multiprocessing.api.SignalException: Process 58030 got signal: 2
52
+
53
+ During handling of the above exception, another exception occurred:
54
+
55
+ Traceback (most recent call last):
56
+ File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/api.py", line 705, in run
57
+ self._shutdown(e.sigval)
58
+ File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py", line 365, in _shutdown
59
+ self._pcontext.close(death_sig)
60
+ File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 572, in close
61
+ self._close(death_sig=death_sig, timeout=timeout)
62
+ File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 909, in _close
63
+ handler.proc.wait(time_to_wait)
64
+ File "/opt/conda/envs/SiT/lib/python3.12/subprocess.py", line 1266, in wait
65
+ return self._wait(timeout=timeout)
66
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^
67
+ File "/opt/conda/envs/SiT/lib/python3.12/subprocess.py", line 2055, in _wait
68
+ time.sleep(delay)
69
+ File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 84, in _terminate_process_handler
70
+ raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
71
+ torch.distributed.elastic.multiprocessing.api.SignalException: Process 58030 got signal: 2
72
+
73
+ During handling of the above exception, another exception occurred:
74
+
75
+ Traceback (most recent call last):
76
+ File "/opt/conda/envs/SiT/bin/torchrun", line 33, in <module>
77
+ sys.exit(load_entry_point('torch==2.5.1', 'console_scripts', 'torchrun')())
78
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
79
+ File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
80
+ return f(*args, **kwargs)
81
+ ^^^^^^^^^^^^^^^^^^
82
+ File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/run.py", line 919, in main
83
+ run(args)
84
+ File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/run.py", line 910, in run
85
+ elastic_launch(
86
+ File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
87
+ return launch_agent(self._config, self._entrypoint, list(args))
88
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
89
+ File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 260, in launch_agent
90
+ result = agent.run()
91
+ ^^^^^^^^^^^
92
+ File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/metrics/api.py", line 137, in wrapper
93
+ result = f(*args, **kwargs)
94
+ ^^^^^^^^^^^^^^^^^^
95
+ File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/api.py", line 710, in run
96
+ self._shutdown()
97
+ File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py", line 365, in _shutdown
98
+ self._pcontext.close(death_sig)
99
+ File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 572, in close
100
+ self._close(death_sig=death_sig, timeout=timeout)
101
+ File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 909, in _close
102
+ handler.proc.wait(time_to_wait)
103
+ File "/opt/conda/envs/SiT/lib/python3.12/subprocess.py", line 1266, in wait
104
+ return self._wait(timeout=timeout)
105
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^
106
+ File "/opt/conda/envs/SiT/lib/python3.12/subprocess.py", line 2055, in _wait
107
+ time.sleep(delay)
108
+ File "/opt/conda/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 84, in _terminate_process_handler
109
+ raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
110
+ torch.distributed.elastic.multiprocessing.api.SignalException: Process 58030 got signal: 2
SiT_back/SiT_clean/__pycache__/download.cpython-312.pyc ADDED
Binary file (1.99 kB). View file
 
SiT_back/SiT_clean/__pycache__/models.cpython-312.pyc ADDED
Binary file (20.8 kB). View file
 
SiT_back/SiT_clean/__pycache__/train_utils.cpython-312.pyc ADDED
Binary file (2.84 kB). View file
 
SiT_back/SiT_clean/download.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This source code is licensed under the license found in the
2
+ # LICENSE file in the root directory of this source tree.
3
+
4
+ """
5
+ Functions for downloading pre-trained SiT models
6
+ """
7
+ from torchvision.datasets.utils import download_url
8
+ import torch
9
+ import os
10
+
11
+
12
+ pretrained_models = {'SiT-XL-2-256x256.pt'}
13
+
14
+
15
+ def find_model(model_name):
16
+ """
17
+ Finds a pre-trained SiT model, downloading it if necessary. Alternatively, loads a model from a local path.
18
+ """
19
+ if model_name in pretrained_models:
20
+ return download_model(model_name)
21
+ else:
22
+ assert os.path.isfile(model_name), f'Could not find SiT checkpoint at {model_name}'
23
+ checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
24
+ if "ema" in checkpoint: # supports checkpoints from train.py
25
+ checkpoint = checkpoint["ema"]
26
+ return checkpoint
27
+
28
+
29
+ def download_model(model_name):
30
+ """
31
+ Downloads a pre-trained SiT model from the web.
32
+ """
33
+ assert model_name in pretrained_models
34
+ local_path = f'pretrained_models/{model_name}'
35
+ if not os.path.isfile(local_path):
36
+ os.makedirs('pretrained_models', exist_ok=True)
37
+ web_path = f'https://www.dl.dropboxusercontent.com/scl/fi/as9oeomcbub47de5g4be0/SiT-XL-2-256.pt?rlkey=uxzxmpicu46coq3msb17b9ofa&dl=0'
38
+ download_url(web_path, 'pretrained_models', filename=model_name)
39
+ model = torch.load(local_path, map_location=lambda storage, loc: storage)
40
+ return model
SiT_back/SiT_clean/models.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This source code is licensed under the license found in the
2
+ # LICENSE file in the root directory of this source tree.
3
+ # --------------------------------------------------------
4
+ # References:
5
+ # GLIDE: https://github.com/openai/glide-text2im
6
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
7
+ # --------------------------------------------------------
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import numpy as np
12
+ import math
13
+ from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
14
+
15
+
16
+ def modulate(x, shift, scale):
17
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
18
+
19
+
20
+ #################################################################################
21
+ # Embedding Layers for Timesteps and Class Labels #
22
+ #################################################################################
23
+
24
+ class TimestepEmbedder(nn.Module):
25
+ """
26
+ Embeds scalar timesteps into vector representations.
27
+ """
28
+ def __init__(self, hidden_size, frequency_embedding_size=256):
29
+ super().__init__()
30
+ self.mlp = nn.Sequential(
31
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
32
+ nn.SiLU(),
33
+ nn.Linear(hidden_size, hidden_size, bias=True),
34
+ )
35
+ self.frequency_embedding_size = frequency_embedding_size
36
+
37
+ @staticmethod
38
+ def timestep_embedding(t, dim, max_period=10000):
39
+ """
40
+ Create sinusoidal timestep embeddings.
41
+ :param t: a 1-D Tensor of N indices, one per batch element.
42
+ These may be fractional.
43
+ :param dim: the dimension of the output.
44
+ :param max_period: controls the minimum frequency of the embeddings.
45
+ :return: an (N, D) Tensor of positional embeddings.
46
+ """
47
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
48
+ half = dim // 2
49
+ freqs = torch.exp(
50
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
51
+ ).to(device=t.device)
52
+ args = t[:, None].float() * freqs[None]
53
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
54
+ if dim % 2:
55
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
56
+ return embedding
57
+
58
+ def forward(self, t):
59
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
60
+ t_emb = self.mlp(t_freq)
61
+ return t_emb
62
+
63
+
64
+ class LabelEmbedder(nn.Module):
65
+ """
66
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
67
+ """
68
+ def __init__(self, num_classes, hidden_size, dropout_prob):
69
+ super().__init__()
70
+ use_cfg_embedding = dropout_prob > 0
71
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
72
+ self.num_classes = num_classes
73
+ self.dropout_prob = dropout_prob
74
+
75
+ def token_drop(self, labels, force_drop_ids=None):
76
+ """
77
+ Drops labels to enable classifier-free guidance.
78
+ """
79
+ if force_drop_ids is None:
80
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
81
+ else:
82
+ drop_ids = force_drop_ids == 1
83
+ labels = torch.where(drop_ids, self.num_classes, labels)
84
+ return labels
85
+
86
+ def forward(self, labels, train, force_drop_ids=None):
87
+ use_dropout = self.dropout_prob > 0
88
+ if (train and use_dropout) or (force_drop_ids is not None):
89
+ labels = self.token_drop(labels, force_drop_ids)
90
+ embeddings = self.embedding_table(labels)
91
+ return embeddings
92
+
93
+
94
+ #################################################################################
95
+ # Core SiT Model #
96
+ #################################################################################
97
+
98
+ class SiTBlock(nn.Module):
99
+ """
100
+ A SiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
101
+ """
102
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
103
+ super().__init__()
104
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
105
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
106
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
107
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
108
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
109
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
110
+ self.adaLN_modulation = nn.Sequential(
111
+ nn.SiLU(),
112
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
113
+ )
114
+
115
+ def forward(self, x, c):
116
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
117
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
118
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
119
+ return x
120
+
121
+
122
+ class FinalLayer(nn.Module):
123
+ """
124
+ The final layer of SiT.
125
+ """
126
+ def __init__(self, hidden_size, patch_size, out_channels):
127
+ super().__init__()
128
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
129
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
130
+ self.adaLN_modulation = nn.Sequential(
131
+ nn.SiLU(),
132
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
133
+ )
134
+
135
+ def forward(self, x, c):
136
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
137
+ x = modulate(self.norm_final(x), shift, scale)
138
+ x = self.linear(x)
139
+ return x
140
+
141
+
142
+ class SiT(nn.Module):
143
+ """
144
+ Diffusion model with a Transformer backbone.
145
+ """
146
+ def __init__(
147
+ self,
148
+ input_size=32,
149
+ patch_size=2,
150
+ in_channels=4,
151
+ hidden_size=1152,
152
+ depth=28,
153
+ num_heads=16,
154
+ mlp_ratio=4.0,
155
+ class_dropout_prob=0.1,
156
+ num_classes=1000,
157
+ learn_sigma=True,
158
+ ):
159
+ super().__init__()
160
+ self.learn_sigma = learn_sigma
161
+ self.learn_sigma = True
162
+ self.in_channels = in_channels
163
+ self.out_channels = in_channels * 2
164
+ self.patch_size = patch_size
165
+ self.num_heads = num_heads
166
+
167
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
168
+ self.t_embedder = TimestepEmbedder(hidden_size)
169
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
170
+ num_patches = self.x_embedder.num_patches
171
+ # Will use fixed sin-cos embedding:
172
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
173
+
174
+ self.blocks = nn.ModuleList([
175
+ SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
176
+ ])
177
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
178
+ self.initialize_weights()
179
+
180
+ def initialize_weights(self):
181
+ # Initialize transformer layers:
182
+ def _basic_init(module):
183
+ if isinstance(module, nn.Linear):
184
+ torch.nn.init.xavier_uniform_(module.weight)
185
+ if module.bias is not None:
186
+ nn.init.constant_(module.bias, 0)
187
+ self.apply(_basic_init)
188
+
189
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
190
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
191
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
192
+
193
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
194
+ w = self.x_embedder.proj.weight.data
195
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
196
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
197
+
198
+ # Initialize label embedding table:
199
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
200
+
201
+ # Initialize timestep embedding MLP:
202
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
203
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
204
+
205
+ # Zero-out adaLN modulation layers in SiT blocks:
206
+ for block in self.blocks:
207
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
208
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
209
+
210
+ # Zero-out output layers:
211
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
212
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
213
+ nn.init.constant_(self.final_layer.linear.weight, 0)
214
+ nn.init.constant_(self.final_layer.linear.bias, 0)
215
+
216
+ def unpatchify(self, x):
217
+ """
218
+ x: (N, T, patch_size**2 * C)
219
+ imgs: (N, H, W, C)
220
+ """
221
+ c = self.out_channels
222
+ p = self.x_embedder.patch_size[0]
223
+ h = w = int(x.shape[1] ** 0.5)
224
+ assert h * w == x.shape[1]
225
+
226
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
227
+ x = torch.einsum('nhwpqc->nchpwq', x)
228
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
229
+ return imgs
230
+
231
+ def forward(self, x, t, y):
232
+ """
233
+ Forward pass of SiT.
234
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
235
+ t: (N,) tensor of diffusion timesteps
236
+ y: (N,) tensor of class labels
237
+ """
238
+ x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
239
+ t = self.t_embedder(t) # (N, D)
240
+ y = self.y_embedder(y, self.training) # (N, D)
241
+ c = t + y # (N, D)
242
+ for block in self.blocks:
243
+ x = block(x, c) # (N, T, D)
244
+ x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
245
+ x = self.unpatchify(x) # (N, out_channels, H, W)
246
+ if self.learn_sigma:
247
+ x, _ = x.chunk(2, dim=1)
248
+ return x
249
+
250
+ def forward_with_cfg(self, x, t, y, cfg_scale):
251
+ """
252
+ Forward pass of SiT, but also batches the unconditional forward pass for classifier-free guidance.
253
+ """
254
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
255
+ half = x[: len(x) // 2]
256
+ combined = torch.cat([half, half], dim=0)
257
+ model_out = self.forward(combined, t, y)
258
+ # For exact reproducibility reasons, we apply classifier-free guidance on only
259
+ # three channels by default. The standard approach to cfg applies it to all channels.
260
+ # This can be done by uncommenting the following line and commenting-out the line following that.
261
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
262
+ eps, rest = model_out[:, :3], model_out[:, 3:]
263
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
264
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
265
+ eps = torch.cat([half_eps, half_eps], dim=0)
266
+ return torch.cat([eps, rest], dim=1)
267
+
268
+
269
+ #################################################################################
270
+ # Sine/Cosine Positional Embedding Functions #
271
+ #################################################################################
272
+ # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
273
+
274
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
275
+ """
276
+ grid_size: int of the grid height and width
277
+ return:
278
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
279
+ """
280
+ grid_h = np.arange(grid_size, dtype=np.float32)
281
+ grid_w = np.arange(grid_size, dtype=np.float32)
282
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
283
+ grid = np.stack(grid, axis=0)
284
+
285
+ grid = grid.reshape([2, 1, grid_size, grid_size])
286
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
287
+ if cls_token and extra_tokens > 0:
288
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
289
+ return pos_embed
290
+
291
+
292
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
293
+ assert embed_dim % 2 == 0
294
+
295
+ # use half of dimensions to encode grid_h
296
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
297
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
298
+
299
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
300
+ return emb
301
+
302
+
303
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
304
+ """
305
+ embed_dim: output dimension for each position
306
+ pos: a list of positions to be encoded: size (M,)
307
+ out: (M, D)
308
+ """
309
+ assert embed_dim % 2 == 0
310
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
311
+ omega /= embed_dim / 2.
312
+ omega = 1. / 10000**omega # (D/2,)
313
+
314
+ pos = pos.reshape(-1) # (M,)
315
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
316
+
317
+ emb_sin = np.sin(out) # (M, D/2)
318
+ emb_cos = np.cos(out) # (M, D/2)
319
+
320
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
321
+ return emb
322
+
323
+
324
+ #################################################################################
325
+ # SiT Configs #
326
+ #################################################################################
327
+
328
+ def SiT_XL_2(**kwargs):
329
+ return SiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
330
+
331
+ def SiT_XL_4(**kwargs):
332
+ return SiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
333
+
334
+ def SiT_XL_8(**kwargs):
335
+ return SiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
336
+
337
+ def SiT_L_2(**kwargs):
338
+ return SiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
339
+
340
+ def SiT_L_4(**kwargs):
341
+ return SiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
342
+
343
+ def SiT_L_8(**kwargs):
344
+ return SiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
345
+
346
+ def SiT_B_2(**kwargs):
347
+ return SiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
348
+
349
+ def SiT_B_4(**kwargs):
350
+ return SiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
351
+
352
+ def SiT_B_8(**kwargs):
353
+ return SiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
354
+
355
+ def SiT_S_2(**kwargs):
356
+ return SiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
357
+
358
+ def SiT_S_4(**kwargs):
359
+ return SiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
360
+
361
+ def SiT_S_8(**kwargs):
362
+ return SiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
363
+
364
+
365
+ SiT_models = {
366
+ 'SiT-XL/2': SiT_XL_2, 'SiT-XL/4': SiT_XL_4, 'SiT-XL/8': SiT_XL_8,
367
+ 'SiT-L/2': SiT_L_2, 'SiT-L/4': SiT_L_4, 'SiT-L/8': SiT_L_8,
368
+ 'SiT-B/2': SiT_B_2, 'SiT-B/4': SiT_B_4, 'SiT-B/8': SiT_B_8,
369
+ 'SiT-S/2': SiT_S_2, 'SiT-S/4': SiT_S_4, 'SiT-S/8': SiT_S_8,
370
+ }
SiT_back/SiT_clean/run.sh ADDED
File without changes
SiT_back/SiT_clean/sample.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This source code is licensed under the license found in the
2
+ # LICENSE file in the root directory of this source tree.
3
+
4
+ """
5
+ Sample new images from a pre-trained SiT.
6
+ """
7
+ import torch
8
+ torch.backends.cuda.matmul.allow_tf32 = True
9
+ torch.backends.cudnn.allow_tf32 = True
10
+ from torchvision.utils import save_image
11
+ from diffusers.models import AutoencoderKL
12
+ from download import find_model
13
+ from models import SiT_models
14
+ from train_utils import parse_ode_args, parse_sde_args, parse_transport_args
15
+ from transport import create_transport, Sampler
16
+ import argparse
17
+ import sys
18
+ from time import time
19
+
20
+
21
+ def main(mode, args):
22
+ # Setup PyTorch:
23
+ torch.manual_seed(args.seed)
24
+ torch.set_grad_enabled(False)
25
+ device = "cuda" if torch.cuda.is_available() else "cpu"
26
+
27
+ if args.ckpt is None:
28
+ assert args.model == "SiT-XL/2", "Only SiT-XL/2 models are available for auto-download."
29
+ assert args.image_size in [256, 512]
30
+ assert args.num_classes == 1000
31
+ assert args.image_size == 256, "512x512 models are not yet available for auto-download." # remove this line when 512x512 models are available
32
+ learn_sigma = args.image_size == 256
33
+ else:
34
+ learn_sigma = False
35
+
36
+ # Load model:
37
+ latent_size = args.image_size // 8
38
+ model = SiT_models[args.model](
39
+ input_size=latent_size,
40
+ num_classes=args.num_classes,
41
+ learn_sigma=learn_sigma,
42
+ ).to(device)
43
+ # Auto-download a pre-trained model or load a custom SiT checkpoint from train.py:
44
+ ckpt_path = args.ckpt or f"SiT-XL-2-{args.image_size}x{args.image_size}.pt"
45
+ state_dict = find_model(ckpt_path)
46
+ model.load_state_dict(state_dict)
47
+ model.eval() # important!
48
+ transport = create_transport(
49
+ args.path_type,
50
+ args.prediction,
51
+ args.loss_weight,
52
+ args.train_eps,
53
+ args.sample_eps
54
+ )
55
+ sampler = Sampler(transport)
56
+ if mode == "ODE":
57
+ if args.likelihood:
58
+ assert args.cfg_scale == 1, "Likelihood is incompatible with guidance"
59
+ sample_fn = sampler.sample_ode_likelihood(
60
+ sampling_method=args.sampling_method,
61
+ num_steps=args.num_sampling_steps,
62
+ atol=args.atol,
63
+ rtol=args.rtol,
64
+ )
65
+ else:
66
+ sample_fn = sampler.sample_ode(
67
+ sampling_method=args.sampling_method,
68
+ num_steps=args.num_sampling_steps,
69
+ atol=args.atol,
70
+ rtol=args.rtol,
71
+ reverse=args.reverse
72
+ )
73
+
74
+ elif mode == "SDE":
75
+ sample_fn = sampler.sample_sde(
76
+ sampling_method=args.sampling_method,
77
+ diffusion_form=args.diffusion_form,
78
+ diffusion_norm=args.diffusion_norm,
79
+ last_step=args.last_step,
80
+ last_step_size=args.last_step_size,
81
+ num_steps=args.num_sampling_steps,
82
+ )
83
+
84
+
85
+ vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
86
+
87
+ # Labels to condition the model with (feel free to change):
88
+ class_labels = [207, 360, 387, 974, 88, 979, 417, 279]
89
+
90
+ # Create sampling noise:
91
+ n = len(class_labels)
92
+ z = torch.randn(n, 4, latent_size, latent_size, device=device)
93
+ y = torch.tensor(class_labels, device=device)
94
+
95
+ # Setup classifier-free guidance:
96
+ z = torch.cat([z, z], 0)
97
+ y_null = torch.tensor([1000] * n, device=device)
98
+ y = torch.cat([y, y_null], 0)
99
+ model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
100
+
101
+ # Sample images:
102
+ start_time = time()
103
+ samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
104
+ samples, _ = samples.chunk(2, dim=0) # Remove null class samples
105
+ samples = vae.decode(samples / 0.18215).sample
106
+ print(f"Sampling took {time() - start_time:.2f} seconds.")
107
+
108
+ # Save and display images:
109
+ save_image(samples, "sample.png", nrow=4, normalize=True, value_range=(-1, 1))
110
+
111
+
112
+ if __name__ == "__main__":
113
+ parser = argparse.ArgumentParser()
114
+
115
+ if len(sys.argv) < 2:
116
+ print("Usage: program.py <mode> [options]")
117
+ sys.exit(1)
118
+
119
+ mode = sys.argv[1]
120
+
121
+ assert mode[:2] != "--", "Usage: program.py <mode> [options]"
122
+ assert mode in ["ODE", "SDE"], "Invalid mode. Please choose 'ODE' or 'SDE'"
123
+
124
+ parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
125
+ parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="mse")
126
+ parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
127
+ parser.add_argument("--num-classes", type=int, default=1000)
128
+ parser.add_argument("--cfg-scale", type=float, default=4.0)
129
+ parser.add_argument("--num-sampling-steps", type=int, default=250)
130
+ parser.add_argument("--seed", type=int, default=0)
131
+ parser.add_argument("--ckpt", type=str, default=None,
132
+ help="Optional path to a SiT checkpoint (default: auto-download a pre-trained SiT-XL/2 model).")
133
+
134
+
135
+ parse_transport_args(parser)
136
+ if mode == "ODE":
137
+ parse_ode_args(parser)
138
+ # Further processing for ODE
139
+ elif mode == "SDE":
140
+ parse_sde_args(parser)
141
+ # Further processing for SDE
142
+
143
+ args = parser.parse_known_args()[0]
144
+ main(mode, args)
SiT_back/SiT_clean/sample_ddp.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This source code is licensed under the license found in the
2
+ # LICENSE file in the root directory of this source tree.
3
+
4
+ """
5
+ Samples a large number of images from a pre-trained SiT model using DDP.
6
+ Subsequently saves a .npz file that can be used to compute FID and other
7
+ evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations
8
+
9
+ For a simple single-GPU/CPU sampling script, see sample.py.
10
+ """
11
+ import torch
12
+ import torch.distributed as dist
13
+ from models import SiT_models
14
+ from download import find_model
15
+ from transport import create_transport, Sampler
16
+ from diffusers.models import AutoencoderKL
17
+ from train_utils import parse_ode_args, parse_sde_args, parse_transport_args
18
+ from tqdm import tqdm
19
+ import os
20
+ from PIL import Image
21
+ import numpy as np
22
+ import math
23
+ import argparse
24
+ import sys
25
+
26
+
27
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Build a single .npz file from a folder of sequentially numbered .png samples.

    Reads `{sample_dir}/000000.png` .. `{sample_dir}/{num-1:06d}.png`, stacks
    them into one uint8 array of shape (num, H, W, 3), and writes it to
    `{sample_dir}.npz` under key "arr_0" (the layout the ADM FID evaluation
    scripts expect). Returns the path of the written .npz file.
    """
    frames = [
        np.asarray(Image.open(f"{sample_dir}/{idx:06d}.png")).astype(np.uint8)
        for idx in tqdm(range(num), desc="Building .npz file from samples")
    ]
    stacked = np.stack(frames)
    # All images must be RGB (3 channels) and share one resolution:
    assert stacked.shape == (num, stacked.shape[1], stacked.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=stacked)
    print(f"Saved .npz file to {npz_path} [shape={stacked.shape}].")
    return npz_path
42
+
43
+
44
def main(mode, args):
    """
    Run DDP sampling: generate images for FID evaluation, one process per GPU.

    `mode` is "ODE" or "SDE" (selected on the command line before argparse);
    `args` carries model/VAE/solver options. Each rank writes individually
    numbered .png files; rank 0 finally packs them into a single .npz.
    """
    torch.backends.cuda.matmul.allow_tf32 = args.tf32  # True: fast but may lead to some small numerical differences
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)  # inference only; no autograd bookkeeping

    # Setup DDP:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    # Per-rank seed so each process draws different latents/labels:
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    if args.ckpt is None:
        # Auto-download path only supports the released SiT-XL/2 @ 256x256:
        assert args.model == "SiT-XL/2", "Only SiT-XL/2 models are available for auto-download."
        assert args.image_size in [256, 512]
        assert args.num_classes == 1000
        assert args.image_size == 256, "512x512 models are not yet available for auto-download."  # remove this line when 512x512 models are available
        learn_sigma = args.image_size == 256
    else:
        learn_sigma = False

    # Load model:
    latent_size = args.image_size // 8  # the VAE downsamples images by 8x
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        learn_sigma=learn_sigma,
    ).to(device)
    # Auto-download a pre-trained model or load a custom SiT checkpoint from train.py:
    ckpt_path = args.ckpt or f"SiT-XL-2-{args.image_size}x{args.image_size}.pt"
    state_dict = find_model(ckpt_path)
    model.load_state_dict(state_dict)
    model.eval()  # important!

    transport = create_transport(
        args.path_type,
        args.prediction,
        args.loss_weight,
        args.train_eps,
        args.sample_eps
    )
    sampler = Sampler(transport)
    # Build the sampling function for the requested solver family:
    if mode == "ODE":
        if args.likelihood:
            # Likelihood evaluation uses the probability-flow ODE; CFG would change it:
            assert args.cfg_scale == 1, "Likelihood is incompatible with guidance"
            sample_fn = sampler.sample_ode_likelihood(
                sampling_method=args.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=args.atol,
                rtol=args.rtol,
            )
        else:
            sample_fn = sampler.sample_ode(
                sampling_method=args.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=args.atol,
                rtol=args.rtol,
                reverse=args.reverse
            )
    elif mode == "SDE":
        sample_fn = sampler.sample_sde(
            sampling_method=args.sampling_method,
            diffusion_form=args.diffusion_form,
            diffusion_norm=args.diffusion_norm,
            last_step=args.last_step,
            last_step_size=args.last_step_size,
            num_steps=args.num_sampling_steps,
        )
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
    assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale be >= 1.0"
    using_cfg = args.cfg_scale > 1.0  # cfg_scale == 1.0 means unguided sampling

    # Create folder to save samples (folder name encodes model + solver settings):
    model_string_name = args.model.replace("/", "-")
    ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
    if mode == "ODE":
        folder_name = f"{model_string_name}-{ckpt_string_name}-" \
                      f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-" \
                      f"{mode}-{args.num_sampling_steps}-{args.sampling_method}"
    elif mode == "SDE":
        folder_name = f"{model_string_name}-{ckpt_string_name}-" \
                      f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-" \
                      f"{mode}-{args.num_sampling_steps}-{args.sampling_method}-" \
                      f"{args.diffusion_form}-{args.last_step}-{args.last_step_size}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()  # ensure the folder exists before any rank writes into it

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()
    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
    num_samples = len([name for name in os.listdir(sample_folder_dir) if (os.path.isfile(os.path.join(sample_folder_dir, name)) and ".png" in name)])
    total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
    assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // dist.get_world_size())
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    # NOTE(review): `done_iterations` (derived from pre-existing .png files) is
    # computed but never read — the loop below always starts from iteration 0.
    # Confirm whether resuming a partial run was intended.
    done_iterations = int(int(num_samples // dist.get_world_size()) // n)
    pbar = range(iterations)
    pbar = tqdm(pbar) if rank == 0 else pbar  # progress bar only on rank 0
    total = 0  # global index offset for the current batch of images

    for i in pbar:
        # Sample inputs:
        z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device)
        y = torch.randint(0, args.num_classes, (n,), device=device)

        # Setup classifier-free guidance: duplicate the batch with the null
        # class so conditional and unconditional passes share one forward call.
        # NOTE(review): the null-class id 1000 is hard-coded; it matches the
        # default --num-classes=1000 but not other values — confirm.
        if using_cfg:
            z = torch.cat([z, z], 0)
            y_null = torch.tensor([1000] * n, device=device)
            y = torch.cat([y, y_null], 0)
            model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
            model_fn = model.forward_with_cfg
        else:
            model_kwargs = dict(y=y)
            model_fn = model.forward

        samples = sample_fn(z, model_fn, **model_kwargs)[-1]  # keep only the final solver state
        if using_cfg:
            samples, _ = samples.chunk(2, dim=0)  # Remove null class samples

        # Undo the 0.18215 latent scaling (applied at encode time in train.py)
        # before decoding, then map to uint8 HWC for PNG export:
        samples = vae.decode(samples / 0.18215).sample
        samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

        # Save samples to disk as individual .png files.
        # NOTE(review): this loop variable `i` shadows the outer iteration
        # index; harmless today because the outer `i` is not read afterwards,
        # but worth renaming.
        for i, sample in enumerate(samples):
            # Interleave ranks so file indices are globally unique:
            index = i * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
        total += global_batch_size
        dist.barrier()

    # Make sure all processes have finished saving their samples before attempting to convert to .npz
    dist.barrier()
    if rank == 0:
        create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
        print("Done.")
    dist.barrier()
    dist.destroy_process_group()
193
+ dist.destroy_process_group()
194
+
195
+
196
if __name__ == "__main__":
    # The sampling mode ("ODE" or "SDE") is read as a positional token *before*
    # argparse runs, because each mode registers a different option group.
    cli = argparse.ArgumentParser()

    if len(sys.argv) < 2:
        print("Usage: program.py <mode> [options]")
        sys.exit(1)

    mode = sys.argv[1]
    assert mode[:2] != "--", "Usage: program.py <mode> [options]"
    assert mode in ["ODE", "SDE"], "Invalid mode. Please choose 'ODE' or 'SDE'"

    # Shared model / sampling / output options:
    cli.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
    cli.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")
    cli.add_argument("--sample-dir", type=str, default="samples")
    cli.add_argument("--per-proc-batch-size", type=int, default=4)
    cli.add_argument("--num-fid-samples", type=int, default=50_000)
    cli.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    cli.add_argument("--num-classes", type=int, default=1000)
    cli.add_argument("--cfg-scale", type=float, default=1.0)
    cli.add_argument("--num-sampling-steps", type=int, default=250)
    cli.add_argument("--global-seed", type=int, default=0)
    cli.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
                     help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.")
    cli.add_argument("--ckpt", type=str, default=None,
                     help="Optional path to a SiT checkpoint (default: auto-download a pre-trained SiT-XL/2 model).")

    # Mode-specific solver options:
    parse_transport_args(cli)
    if mode == "ODE":
        parse_ode_args(cli)
        # Further processing for ODE
    elif mode == "SDE":
        parse_sde_args(cli)
        # Further processing for SDE

    # parse_known_args() tolerates the positional <mode> token:
    args = cli.parse_known_args()[0]
    main(mode, args)
SiT_back/SiT_clean/train.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This source code is licensed under the license found in the
2
+ # LICENSE file in the root directory of this source tree.
3
+
4
+ """
5
+ A minimal training script for SiT using PyTorch DDP.
6
+ """
7
+ import torch
8
+ # the first flag below was False when we tested this script but True makes A100 training a lot faster:
9
+ torch.backends.cuda.matmul.allow_tf32 = True
10
+ torch.backends.cudnn.allow_tf32 = True
11
+ import torch.distributed as dist
12
+ from torch.nn.parallel import DistributedDataParallel as DDP
13
+ from torch.utils.data import DataLoader
14
+ from torch.utils.data.distributed import DistributedSampler
15
+ from torchvision.datasets import ImageFolder
16
+ from torchvision import transforms
17
+ import numpy as np
18
+ from collections import OrderedDict
19
+ from PIL import Image
20
+ from copy import deepcopy
21
+ from glob import glob
22
+ from time import time
23
+ import argparse
24
+ import logging
25
+ import os
26
+
27
+ from models import SiT_models
28
+ from download import find_model
29
+ from transport import create_transport, Sampler
30
+ from diffusers.models import AutoencoderKL
31
+ from train_utils import parse_transport_args
32
+
33
+ #################################################################################
34
+ # Training Helper Functions #
35
+ #################################################################################
36
+
37
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model, in place:
    ema <- decay * ema + (1 - decay) * param, per named parameter.
    """
    live_params = dict(model.named_parameters())
    for name, ema_param in ema_model.named_parameters():
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        ema_param.mul_(decay).add_(live_params[name].data, alpha=1 - decay)
48
+
49
+
50
def requires_grad(model, flag=True):
    """
    Enable (flag=True) or disable (flag=False) gradient tracking on every
    parameter of `model`, in place.
    """
    for param in model.parameters():
        param.requires_grad_(flag)
56
+
57
+
58
def cleanup():
    """
    End DDP training.

    Tears down the process group created by dist.init_process_group() in
    main(); call once per process after training finishes.
    """
    dist.destroy_process_group()
63
+
64
+
65
def create_logger(logging_dir):
    """
    Create a logger that writes to `{logging_dir}/log.txt` and stdout on the
    rank-0 process; every other rank receives a no-op logger.
    """
    logger = logging.getLogger(__name__)
    if dist.get_rank() != 0:
        # Non-zero ranks: swallow all records.
        logger.addHandler(logging.NullHandler())
        return logger
    # Rank 0: configure the root logger once with console + file handlers.
    logging.basicConfig(
        level=logging.INFO,
        format='[\033[34m%(asctime)s\033[0m] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
    )
    return logger
81
+
82
+
83
def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Repeatedly box-downsamples by 2x while the short side is at least twice the
    target, then bicubic-resizes so the short side equals `image_size`, and
    finally takes a centered `image_size` x `image_size` crop.
    """
    while min(*pil_image.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in pil_image.size)
        pil_image = pil_image.resize(halved, resample=Image.BOX)

    ratio = image_size / min(*pil_image.size)
    target = tuple(round(side * ratio) for side in pil_image.size)
    pil_image = pil_image.resize(target, resample=Image.BICUBIC)

    arr = np.array(pil_image)
    top = (arr.shape[0] - image_size) // 2
    left = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[top: top + image_size, left: left + image_size])
102
+
103
+
104
+ #################################################################################
105
+ # Training Loop #
106
+ #################################################################################
107
+
108
def main(args):
    """
    Train a new SiT model with PyTorch DDP.

    One process per GPU. Rank 0 additionally owns the experiment folder,
    logging, checkpointing, and periodic sample-image grids saved under
    `{experiment_dir}/pic`.
    """
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."

    # Setup DDP (NCCL backend, one process per visible GPU):
    dist.init_process_group("nccl")
    assert args.global_batch_size % dist.get_world_size() == 0, "Batch size must be divisible by world size."
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    # Distinct seed per rank so each process draws different noise:
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
    local_batch_size = int(args.global_batch_size // dist.get_world_size())

    # Setup an experiment folder (rank 0 only; other ranks get a no-op logger):
    if rank == 0:
        os.makedirs(args.results_dir, exist_ok=True)  # Make results folder (holds all experiment subfolders)
        experiment_index = len(glob(f"{args.results_dir}/*"))
        model_string_name = args.model.replace("/", "-")  # e.g., SiT-XL/2 --> SiT-XL-2 (for naming folders)
        experiment_name = f"{experiment_index:03d}-{model_string_name}-" \
                          f"{args.path_type}-{args.prediction}-{args.loss_weight}"
        experiment_dir = f"{args.results_dir}/{experiment_name}"  # Create an experiment folder
        checkpoint_dir = f"{experiment_dir}/checkpoints"  # Stores saved model checkpoints
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Create pic directory for saving sample images
        pic_dir = f"{experiment_dir}/pic"
        os.makedirs(pic_dir, exist_ok=True)

        logger = create_logger(experiment_dir)
        logger.info(f"Experiment directory created at {experiment_dir}")
        logger.info(f"Sample images will be saved to {pic_dir}")
    else:
        logger = create_logger(None)

    # Create model:
    assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
    latent_size = args.image_size // 8
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes
    )

    # Note that parameter initialization is done within the SiT constructor
    ema = deepcopy(model).to(device)  # Create an EMA of the model for use after training

    # Optionally resume model/EMA weights. The optimizer state is restored
    # further below, after `opt` has been created.
    resume_state = None
    if args.ckpt is not None:
        resume_state = find_model(args.ckpt)
        model.load_state_dict(resume_state["model"])
        ema.load_state_dict(resume_state["ema"])
        # NOTE(review): this wholesale replaces the CLI args (paths included)
        # with the ones stored in the checkpoint — confirm that is intended.
        args = resume_state["args"]

    requires_grad(ema, False)

    model = DDP(model.to(device), device_ids=[device])
    transport = create_transport(
        args.path_type,
        args.prediction,
        args.loss_weight,
        args.train_eps,
        args.sample_eps
    )  # default: velocity;
    transport_sampler = Sampler(transport)
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
    logger.info(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0)
    if resume_state is not None:
        # BUG FIX: the optimizer state used to be loaded *before* `opt` was
        # created, which raised a NameError whenever --ckpt was passed. It is
        # now restored here, once `opt` exists and params live on the GPU.
        opt.load_state_dict(resume_state["opt"])

    # Setup data:
    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])
    dataset = ImageFolder(args.data_path, transform=transform)
    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=rank,
        shuffle=True,
        seed=args.global_seed
    )
    loader = DataLoader(
        dataset,
        batch_size=local_batch_size,
        shuffle=False,  # shuffling is handled by the DistributedSampler
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})")

    # Prepare models for training:
    update_ema(ema, model.module, decay=0)  # Ensure EMA is initialized with synced weights
    model.train()  # important! This enables embedding dropout for classifier-free guidance
    ema.eval()  # EMA model should always be in eval mode

    # Variables for monitoring/logging purposes:
    train_steps = 0
    log_steps = 0
    running_loss = 0
    start_time = time()

    # Labels to condition the model with (feel free to change):
    ys = torch.randint(1000, size=(local_batch_size,), device=device)
    use_cfg = args.cfg_scale > 1.0
    # Create sampling noise:
    n = ys.size(0)
    zs = torch.randn(n, 4, latent_size, latent_size, device=device)

    # Create fixed sampling noise and conditions for consistent sampling visualization
    fixed_ys = torch.randint(1000, size=(16,), device=device)  # Fixed labels for sampling
    fixed_zs = torch.randn(16, 4, latent_size, latent_size, device=device)  # Fixed noise for sampling

    # Setup classifier-free guidance (1000 is used as the null-class id):
    if use_cfg:
        zs = torch.cat([zs, zs], 0)
        y_null = torch.tensor([1000] * n, device=device)
        ys = torch.cat([ys, y_null], 0)
        sample_model_kwargs = dict(y=ys, cfg_scale=args.cfg_scale)
        model_fn = ema.forward_with_cfg
    else:
        sample_model_kwargs = dict(y=ys)
        model_fn = ema.forward

    # Setup fixed classifier-free guidance for sampling:
    if args.cfg_scale > 1.0:
        fixed_zs = torch.cat([fixed_zs, fixed_zs], 0)
        fixed_y_null = torch.tensor([1000] * 16, device=device)
        fixed_ys = torch.cat([fixed_ys, fixed_y_null], 0)
        fixed_sample_model_kwargs = dict(y=fixed_ys, cfg_scale=args.cfg_scale)
    else:
        fixed_sample_model_kwargs = dict(y=fixed_ys)

    logger.info(f"Training for {args.epochs} epochs...")
    for epoch in range(args.epochs):
        sampler.set_epoch(epoch)  # reshuffle shards each epoch
        logger.info(f"Beginning epoch {epoch}...")
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            with torch.no_grad():
                # Map input images to latent space + normalize latents:
                x = vae.encode(x).latent_dist.sample().mul_(0.18215)
            model_kwargs = dict(y=y)
            loss_dict = transport.training_losses(model, x, model_kwargs)
            loss = loss_dict["loss"].mean()
            opt.zero_grad()
            loss.backward()
            opt.step()
            update_ema(ema, model.module)

            # Log loss values:
            running_loss += loss.item()
            log_steps += 1
            train_steps += 1
            if train_steps % args.log_every == 0:
                # Measure training speed:
                torch.cuda.synchronize()
                end_time = time()
                steps_per_sec = log_steps / (end_time - start_time)
                # Reduce loss history over all processes:
                avg_loss = torch.tensor(running_loss / log_steps, device=device)
                dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                avg_loss = avg_loss.item() / dist.get_world_size()
                logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
                # Reset monitoring variables:
                running_loss = 0
                log_steps = 0
                start_time = time()

            # Save SiT checkpoint:
            if train_steps % args.ckpt_every == 0 and train_steps > 0:
                if rank == 0:
                    checkpoint = {
                        "model": model.module.state_dict(),
                        "ema": ema.state_dict(),
                        "opt": opt.state_dict(),
                        "args": args
                    }
                    checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
                    torch.save(checkpoint, checkpoint_path)
                    logger.info(f"Saved checkpoint to {checkpoint_path}")
                dist.barrier()

            # Save sample images:
            if train_steps % args.sample_every == 0 and train_steps > 0:
                logger.info("Generating EMA samples...")
                sample_fn = transport_sampler.sample_ode()  # default to ode sampling
                samples = sample_fn(fixed_zs, model_fn, **fixed_sample_model_kwargs)[-1]
                dist.barrier()

                if args.cfg_scale > 1.0:  # remove null samples
                    samples, _ = samples.chunk(2, dim=0)
                samples = vae.decode(samples / 0.18215).sample

                # Save sample images to pic directory instead of wandb
                if rank == 0:
                    # Normalize images from [-1, 1] to [0, 1]
                    samples = (samples.clamp(-1, 1) + 1) / 2
                    # Arrange up to 16 samples in a 4x4 grid image:
                    grid_size = args.image_size
                    grid_image = Image.new('RGB', (4 * grid_size, 4 * grid_size))

                    # Place each sample in the grid
                    for i in range(min(16, samples.shape[0])):
                        # Convert to PIL Image
                        img = samples[i].permute(1, 2, 0).cpu().detach().numpy()
                        img = (img * 255).astype(np.uint8)
                        pil_img = Image.fromarray(img)

                        # Calculate position in the grid
                        row = i // 4
                        col = i % 4
                        grid_image.paste(pil_img, (col * grid_size, row * grid_size))

                    # Save the grid image
                    img_path = f"{pic_dir}/step_{train_steps:07d}_samples_grid.png"
                    grid_image.save(img_path)
                    logger.info(f"Saved sample images grid to {img_path}")

                # BUG FIX: was `logging.info(...)`, which targets the root
                # logger instead of the rank-aware logger created above.
                logger.info("Generating EMA samples done.")

    model.eval()  # important! This disables randomized embedding dropout
    # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...

    logger.info("Done!")
    cleanup()
+
348
+
349
if __name__ == "__main__":
    # Default args here will train SiT-XL/2 with the hyperparameters we used in our paper (except training iters).
    cli = argparse.ArgumentParser()
    cli.add_argument("--data-path", type=str, default="/gemini/platform/public/hzh/datasets/Imagenet/train/")
    cli.add_argument("--results-dir", type=str, default="results")
    cli.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
    cli.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    cli.add_argument("--num-classes", type=int, default=1000)
    cli.add_argument("--epochs", type=int, default=140000)
    cli.add_argument("--global-batch-size", type=int, default=256)
    cli.add_argument("--global-seed", type=int, default=0)
    cli.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")  # Choice doesn't affect training
    cli.add_argument("--num-workers", type=int, default=4)
    cli.add_argument("--log-every", type=int, default=100)
    cli.add_argument("--ckpt-every", type=int, default=10)
    cli.add_argument("--sample-every", type=int, default=10)
    cli.add_argument("--cfg-scale", type=float, default=4.0)
    cli.add_argument("--ckpt", type=str, default=None,
                     help="Optional path to a custom SiT checkpoint")

    # Transport/interpolant options shared with the sampling scripts:
    parse_transport_args(cli)
    main(cli.parse_args())
SiT_back/SiT_clean/train_utils.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def none_or_str(value):
    """argparse `type=` helper: map the literal string 'None' to Python None."""
    return None if value == 'None' else value
5
+
6
def parse_transport_args(parser):
    """Register the transport/interpolant options shared by training and sampling.

    Adds: --path-type, --prediction, --loss-weight, --sample-eps, --train-eps.
    """
    group = parser.add_argument_group("Transport arguments")
    group.add_argument("--path-type", type=str, default="Linear", choices=["Linear", "GVP", "VP"])
    group.add_argument("--prediction", type=str, default="velocity", choices=["velocity", "score", "noise"])
    # `none_or_str` maps the literal string "None" back to Python None:
    group.add_argument("--loss-weight", type=none_or_str, default=None, choices=[None, "velocity", "likelihood"])
    # Both epsilons default to None; create_transport substitutes path-dependent values.
    group.add_argument("--sample-eps", type=float)
    group.add_argument("--train-eps", type=float)
13
+
14
def parse_ode_args(parser):
    """Register the ODE-solver options (used when the script runs in ODE mode)."""
    group = parser.add_argument_group("ODE arguments")
    group.add_argument(
        "--sampling-method", type=str, default="dopri5",
        help="blackbox ODE solver methods; for full list check https://github.com/rtqichen/torchdiffeq",
    )
    group.add_argument("--atol", type=float, default=1e-6, help="Absolute tolerance")
    group.add_argument("--rtol", type=float, default=1e-3, help="Relative tolerance")
    group.add_argument("--reverse", action="store_true")
    group.add_argument("--likelihood", action="store_true")
21
+
22
def parse_sde_args(parser):
    """Register the SDE-solver options (used when the script runs in SDE mode)."""
    group = parser.add_argument_group("SDE arguments")
    group.add_argument("--sampling-method", type=str, default="Euler", choices=["Euler", "Heun"])
    group.add_argument("--diffusion-form", type=str, default="sigma", \
                       choices=["constant", "SBDM", "sigma", "linear", "decreasing", "increasing-decreasing"],\
                       help="form of diffusion coefficient in the SDE")
    group.add_argument("--diffusion-norm", type=float, default=1.0)
    # `none_or_str` lets the user pass the literal "None" to disable the last step:
    group.add_argument("--last-step", type=none_or_str, default="Mean", choices=[None, "Mean", "Tweedie", "Euler"],\
                       help="form of last step taken in the SDE")
    group.add_argument("--last-step-size", type=float, default=0.04, \
                       help="size of the last step taken")
SiT_back/SiT_clean/transport/__init__.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .transport import Transport, ModelType, WeightType, PathType, Sampler
2
+
3
def create_transport(
    path_type='Linear',
    prediction="velocity",
    loss_weight=None,
    train_eps=None,
    sample_eps=None,
):
    """function for creating Transport object
    **Note**: model prediction defaults to velocity
    Args:
    - path_type: type of path to use; default to linear ("Linear", "GVP", or "VP")
    - prediction: what the model predicts ("velocity", "score", or "noise")
    - loss_weight: optional loss weighting (None, "velocity", or "likelihood")
    - train_eps: small epsilon for avoiding instability during training
    - sample_eps: small epsilon for avoiding instability during sampling
    """

    if prediction == "noise":
        model_type = ModelType.NOISE
    elif prediction == "score":
        model_type = ModelType.SCORE
    else:
        model_type = ModelType.VELOCITY

    if loss_weight == "velocity":
        loss_type = WeightType.VELOCITY
    elif loss_weight == "likelihood":
        loss_type = WeightType.LIKELIHOOD
    else:
        loss_type = WeightType.NONE

    path_choice = {
        "Linear": PathType.LINEAR,
        "GVP": PathType.GVP,
        "VP": PathType.VP,
    }

    path_type = path_choice[path_type]

    # Default the epsilons depending on where the chosen path/prediction
    # combination is numerically unstable near the interval endpoints.
    if (path_type in [PathType.VP]):
        train_eps = 1e-5 if train_eps is None else train_eps
        # BUG FIX: this previously tested `train_eps is None`, so passing an
        # explicit train_eps left sample_eps as the caller's (possibly None)
        # value instead of applying the 1e-3 default.
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY):
        train_eps = 1e-3 if train_eps is None else train_eps
        # BUG FIX: same wrong-variable check as above.
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    else:  # velocity & [GVP, LINEAR] is stable everywhere
        train_eps = 0
        sample_eps = 0

    # create flow state
    state = Transport(
        model_type=model_type,
        path_type=path_type,
        loss_type=loss_type,
        train_eps=train_eps,
        sample_eps=sample_eps,
    )

    return state
SiT_back/SiT_clean/transport/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (2.19 kB). View file
 
SiT_back/SiT_clean/transport/__pycache__/integrators.cpython-312.pyc ADDED
Binary file (6.21 kB). View file
 
SiT_back/SiT_clean/transport/__pycache__/path.cpython-312.pyc ADDED
Binary file (11.3 kB). View file
 
SiT_back/SiT_clean/transport/__pycache__/transport.cpython-312.pyc ADDED
Binary file (20.4 kB). View file
 
SiT_back/SiT_clean/transport/__pycache__/utils.cpython-312.pyc ADDED
Binary file (1.87 kB). View file
 
SiT_back/SiT_clean/transport/integrators.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch as th
3
+ import torch.nn as nn
4
+ from torchdiffeq import odeint
5
+ from functools import partial
6
+ from tqdm import tqdm
7
+
8
class sde:
    """SDE solver class.

    Integrates dx = drift(x, t) dt + sqrt(2 * diffusion(x, t)) dW forward in
    time on a uniform grid, with an Euler-Maruyama or Heun stepper.
    """
    def __init__(
        self,
        drift,
        diffusion,
        *,
        t0,
        t1,
        num_steps,
        sampler_type,
    ):
        """
        Args:
            drift: callable (x, t, model, **kwargs) -> drift term
            diffusion: callable (x, t) -> diffusion coefficient
            t0, t1: start/end times; must satisfy t0 < t1
            num_steps: number of grid points between t0 and t1 (inclusive)
            sampler_type: "Euler" or "Heun"
        """
        assert t0 < t1, "SDE sampler has to be in forward time"

        self.num_timesteps = num_steps
        self.t = th.linspace(t0, t1, num_steps)
        self.dt = self.t[1] - self.t[0]
        self.drift = drift
        self.diffusion = diffusion
        self.sampler_type = sampler_type

    def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
        """One Euler-Maruyama step; returns (noisy state, noise-free mean)."""
        w_cur = th.randn(x.size()).to(x)
        t = th.ones(x.size(0)).to(x) * t  # broadcast scalar time to the batch
        dw = w_cur * th.sqrt(self.dt)
        drift = self.drift(x, t, model, **model_kwargs)
        diffusion = self.diffusion(x, t)
        mean_x = x + drift * self.dt
        x = mean_x + th.sqrt(2 * diffusion) * dw
        return x, mean_x

    def __Heun_step(self, x, _, t, model, **model_kwargs):
        """One Heun (2nd-order) step; noise is injected before the drift update."""
        w_cur = th.randn(x.size()).to(x)
        dw = w_cur * th.sqrt(self.dt)
        t_cur = th.ones(x.size(0)).to(x) * t
        diffusion = self.diffusion(x, t_cur)
        xhat = x + th.sqrt(2 * diffusion) * dw
        K1 = self.drift(xhat, t_cur, model, **model_kwargs)
        xp = xhat + self.dt * K1
        K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
        return xhat + 0.5 * self.dt * (K1 + K2), xhat  # at last time point we do not perform the heun step

    def __forward_fn(self):
        """Resolve the configured stepper; raise for unknown sampler types.

        TODO: generalize here by adding all private functions ending with steps to it
        """
        sampler_dict = {
            "Euler": self.__Euler_Maruyama_step,
            "Heun": self.__Heun_step,
        }

        try:
            sampler = sampler_dict[self.sampler_type]
        except KeyError as err:
            # FIX: was a bare `except:` (which also masked unrelated errors)
            # with a typo'd message ("Smapler").
            raise NotImplementedError(
                f"Sampler type '{self.sampler_type}' not implemented."
            ) from err

        return sampler

    def sample(self, init, model, **model_kwargs):
        """Forward loop of the SDE.

        Returns the list of per-step states (length num_steps - 1); the last
        entry is the final sample.
        """
        x = init
        mean_x = init
        samples = []
        sampler = self.__forward_fn()
        for ti in self.t[:-1]:
            with th.no_grad():
                x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
                samples.append(x)

        return samples
76
+
77
class ode:
    """ODE solver class: wraps a black-box `odeint` call over the drift field."""
    def __init__(
        self,
        drift,
        *,
        t0,
        t1,
        sampler_type,
        num_steps,
        atol,
        rtol,
    ):
        self.drift = drift
        self.t = th.linspace(t0, t1, num_steps)
        self.atol = atol
        self.rtol = rtol
        self.sampler_type = sampler_type

    def sample(self, x, model, **model_kwargs):
        """Integrate the drift field from t0 to t1 starting at `x`.

        `x` may be a single tensor or a tuple of tensors (e.g. for augmented
        likelihood states); returns the full odeint trajectory.
        """
        is_tuple = isinstance(x, tuple)
        device = x[0].device if is_tuple else x.device

        def velocity(t, state):
            # Broadcast the scalar solver time to a per-sample time vector:
            ref = state[0] if isinstance(state, tuple) else state
            t_vec = th.ones(ref.size(0)).to(device) * t
            return self.drift(state, t_vec, model, **model_kwargs)

        n_terms = len(x) if is_tuple else 1
        return odeint(
            velocity,
            x,
            self.t.to(device),
            method=self.sampler_type,
            atol=[self.atol] * n_terms,
            rtol=[self.rtol] * n_terms,
        )
SiT_back/SiT_clean/transport/path.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch as th
2
+ import numpy as np
3
+ from functools import partial
4
+
5
def expand_t_like_x(t, x):
    """Function to reshape time t to broadcastable dimension of x
    Args:
        t: [batch_dim,], time vector
        x: [batch_dim,...], data point
    Returns:
        t viewed as [batch_dim, 1, ..., 1] so it broadcasts against x.
    """
    dims = [1] * (len(x.size()) - 1)
    t = t.view(t.size(0), *dims)
    return t


#################### Coupling Plans ####################

class ICPlan:
    """Linear Coupling Plan: x_t = alpha_t * x1 + sigma_t * x0 with
    alpha_t = t and sigma_t = 1 - t (data at t = 1, noise at t = 0)."""

    def __init__(self, sigma=0.0):
        # Kept for interface parity with stochastic variants; unused here.
        self.sigma = sigma

    def compute_alpha_t(self, t):
        """Return (alpha_t, d_alpha_t): data coefficient and its derivative."""
        return t, 1

    def compute_sigma_t(self, t):
        """Return (sigma_t, d_sigma_t): noise coefficient and its derivative."""
        return 1 - t, -1

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Compute the ratio between d_alpha and alpha (diverges at t == 0)."""
        return 1 / t

    def compute_drift(self, x, t):
        """We always output sde according to score parametrization;
        returns (-drift, diffusion)."""
        t = expand_t_like_x(t, x)
        alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        drift = alpha_ratio * x
        diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t

        return -drift, diffusion

    def compute_diffusion(self, x, t, form="constant", norm=1.0):
        """Compute the diffusion term of the SDE
        Args:
            x: [batch_dim, ...], data point
            t: [batch_dim,], time vector
            form: str, form of the diffusion term
            norm: float, norm of the diffusion term
        Raises:
            NotImplementedError: for an unknown ``form``.
        """
        t = expand_t_like_x(t, x)
        # FIX: dispatch lazily. The original dict evaluated *every* branch
        # eagerly, so requesting e.g. "constant" still computed the SBDM
        # drift (which divides by t and warns/produces inf at t == 0).
        choices = {
            "constant": lambda: norm,
            "SBDM": lambda: norm * self.compute_drift(x, t)[1],
            "sigma": lambda: norm * self.compute_sigma_t(t)[0],
            "linear": lambda: norm * (1 - t),
            "decreasing": lambda: 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
            # Historical misspelling kept for backward compatibility;
            # correctly spelled alias added below.
            "inccreasing-decreasing": lambda: norm * th.sin(np.pi * t) ** 2,
            "increasing-decreasing": lambda: norm * th.sin(np.pi * t) ** 2,
        }

        try:
            diffusion_fn = choices[form]
        except KeyError:
            raise NotImplementedError(f"Diffusion form {form} not implemented")

        return diffusion_fn()

    def get_score_from_velocity(self, velocity, x, t):
        """Wrapper function: transfrom velocity prediction model to score
        Args:
            velocity: [batch_dim, ...] shaped tensor; velocity model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
        score = (reverse_alpha_ratio * velocity - mean) / var
        return score

    def get_noise_from_velocity(self, velocity, x, t):
        """Wrapper function: transfrom velocity prediction model to denoiser
        Args:
            velocity: [batch_dim, ...] shaped tensor; velocity model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = reverse_alpha_ratio * d_sigma_t - sigma_t
        noise = (reverse_alpha_ratio * velocity - mean) / var
        return noise

    def get_velocity_from_score(self, score, x, t):
        """Wrapper function: transfrom score prediction model to velocity
        Args:
            score: [batch_dim, ...] shaped tensor; score model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        drift, var = self.compute_drift(x, t)
        velocity = var * score - drift
        return velocity

    def compute_mu_t(self, t, x0, x1):
        """Compute the mean of time-dependent density p_t"""
        t = expand_t_like_x(t, x1)
        alpha_t, _ = self.compute_alpha_t(t)
        sigma_t, _ = self.compute_sigma_t(t)
        return alpha_t * x1 + sigma_t * x0

    def compute_xt(self, t, x0, x1):
        """Sample xt from time-dependent density p_t; rng is required"""
        xt = self.compute_mu_t(t, x0, x1)
        return xt

    def compute_ut(self, t, x0, x1, xt):
        """Compute the vector field corresponding to p_t"""
        t = expand_t_like_x(t, x1)
        _, d_alpha_t = self.compute_alpha_t(t)
        _, d_sigma_t = self.compute_sigma_t(t)
        return d_alpha_t * x1 + d_sigma_t * x0

    def plan(self, t, x0, x1):
        """Return (t, x_t, u_t): the interpolated point and its target velocity."""
        xt = self.compute_xt(t, x0, x1)
        ut = self.compute_ut(t, x0, x1, xt)
        return t, xt, ut
137
+
138
+
139
class VPCPlan(ICPlan):
    """class for VP path flow matching

    Variance-preserving schedule with a linear beta(t); note that time is
    reversed relative to the score-SDE convention (t = 1 is data, t = 0 is noise).
    """

    def __init__(self, sigma_min=0.1, sigma_max=20.0):
        # Endpoints of the linear beta schedule.
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        # log alpha_t of the VP path as a function of t.
        self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min
        # d/dt of log_mean_coeff.
        self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min


    def compute_alpha_t(self, t):
        """Compute coefficient of x1 (and its time derivative)."""
        alpha_t = self.log_mean_coeff(t)
        alpha_t = th.exp(alpha_t)
        d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
        return alpha_t, d_alpha_t

    def compute_sigma_t(self, t):
        """Compute coefficient of x0 (and its time derivative).

        NOTE(review): divides by sigma_t, which is 0 at t == 1 — callers are
        expected to stay away from the endpoint (see Transport.check_interval).
        """
        p_sigma_t = 2 * self.log_mean_coeff(t)
        sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
        d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
        return sigma_t, d_sigma_t

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
        # d(log alpha)/dt equals d_alpha/alpha; avoids exp/div round-off.
        return self.d_log_mean_coeff(t)

    def compute_drift(self, x, t):
        """Compute the drift term of the SDE; returns (-0.5*beta_t*x, beta_t/2)."""
        t = expand_t_like_x(t, x)
        beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
        return -0.5 * beta_t * x, beta_t / 2
172
+
173
+
174
class GVPCPlan(ICPlan):
    """Generalized VP (trigonometric) coupling:
    alpha_t = sin(pi*t/2), sigma_t = cos(pi*t/2)."""

    def __init__(self, sigma=0.0):
        super().__init__(sigma)

    def compute_alpha_t(self, t):
        """Return (alpha_t, d_alpha_t): coefficient of x1 and its derivative."""
        half_pi = np.pi / 2
        phase = t * half_pi
        return th.sin(phase), half_pi * th.cos(phase)

    def compute_sigma_t(self, t):
        """Return (sigma_t, d_sigma_t): coefficient of x0 and its derivative."""
        half_pi = np.pi / 2
        phase = t * half_pi
        return th.cos(phase), -half_pi * th.sin(phase)

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Numerically stable d_alpha_t / alpha_t = (pi/2) * cot(pi*t/2)."""
        return np.pi / (2 * th.tan(t * np.pi / 2))
SiT_back/SiT_clean/transport/transport.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch as th
2
+ import numpy as np
3
+ import logging
4
+
5
+ import enum
6
+
7
+ from . import path
8
+ from .utils import EasyDict, log_state, mean_flat
9
+ from .integrators import ode, sde
10
+
11
class ModelType(enum.Enum):
    """
    Which type of output the model predicts.
    """

    NOISE = enum.auto()  # the model predicts epsilon
    SCORE = enum.auto()  # the model predicts \nabla \log p(x)
    VELOCITY = enum.auto()  # the model predicts v(x)
19
+
20
class PathType(enum.Enum):
    """
    Which type of path to use.
    """

    LINEAR = enum.auto()  # linear interpolant (path.ICPlan)
    GVP = enum.auto()     # trigonometric generalized-VP path (path.GVPCPlan)
    VP = enum.auto()      # variance-preserving path (path.VPCPlan)
28
+
29
class WeightType(enum.Enum):
    """
    Which type of weighting to use.
    """

    NONE = enum.auto()        # unweighted loss
    VELOCITY = enum.auto()    # weight matching the velocity objective
    LIKELIHOOD = enum.auto()  # weight matching the likelihood objective
37
+
38
+
39
class Transport:
    """Couples an interpolant path (linear / GVP / VP) with a model
    parametrization (noise / score / velocity) and provides training losses
    plus the drift and score functions needed for ODE/SDE sampling."""

    def __init__(
        self,
        *,
        model_type,
        path_type,
        loss_type,
        train_eps,
        sample_eps,
    ):
        """
        Args:
            model_type: ModelType; what the network predicts.
            path_type: PathType; which coupling plan to instantiate.
            loss_type: WeightType; loss weighting for noise/score models.
            train_eps: endpoint clipping of t during training.
            sample_eps: endpoint clipping of t during sampling.
        """
        path_options = {
            PathType.LINEAR: path.ICPlan,
            PathType.GVP: path.GVPCPlan,
            PathType.VP: path.VPCPlan,
        }

        self.loss_type = loss_type
        self.model_type = model_type
        self.path_sampler = path_options[path_type]()
        self.train_eps = train_eps
        self.sample_eps = sample_eps

    def prior_logp(self, z):
        '''
        Standard multivariate normal prior
        Assume z is batched
        '''
        shape = th.tensor(z.size())
        N = th.prod(shape[1:])
        # log N(z; 0, I), summed over all non-batch dims; vmapped over batch.
        _fn = lambda x: -N / 2. * np.log(2 * np.pi) - th.sum(x ** 2) / 2.
        return th.vmap(_fn)(z)


    def check_interval(
        self,
        train_eps,
        sample_eps,
        *,
        diffusion_form="SBDM",
        sde=False,
        reverse=False,
        eval=False,
        last_step_size=0.0,
    ):
        """Return the integration interval (t0, t1), clipped away from the
        endpoints where the chosen path/parametrization is singular.

        Args:
            train_eps / sample_eps: endpoint clipping; which one applies is
                selected by ``eval``.
            diffusion_form: SDE diffusion form (affects the t0 clip for SBDM).
            sde: whether the interval is for SDE (vs ODE) integration.
            reverse: flip the interval (data -> noise direction).
            eval: use sample_eps instead of train_eps.
            last_step_size: size of the final denoising step for SDEs.
        """
        t0 = 0
        t1 = 1
        eps = train_eps if not eval else sample_eps
        if (type(self.path_sampler) in [path.VPCPlan]):

            # VP path is singular at t = 1 (sigma_t -> 0); back off from it.
            t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size

        elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) \
            and (self.model_type != ModelType.VELOCITY or sde): # avoid numerical issue by taking a first semi-implicit step

            t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0
            t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size

        if reverse:
            t0, t1 = 1 - t0, 1 - t1

        return t0, t1


    def sample(self, x1):
        """Sampling x0 & t based on shape of x1 (if needed)
        Args:
          x1 - data point; [batch, *dim]
        Returns:
          (t, x0, x1): uniform times in [t0, t1), Gaussian noise, and the data.
        """

        x0 = th.randn_like(x1)
        t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
        t = th.rand((x1.shape[0],)) * (t1 - t0) + t0
        t = t.to(x1)
        return t, x0, x1


    def training_losses(
        self,
        model,
        x1,
        model_kwargs=None
    ):
        """Loss for training the score model
        Args:
        - model: backbone model; could be score, noise, or velocity
        - x1: datapoint
        - model_kwargs: additional arguments for the model
        Returns:
            dict with 'pred' (model output) and 'loss' (per-sample loss).
        """
        if model_kwargs == None:
            model_kwargs = {}

        t, x0, x1 = self.sample(x1)
        t, xt, ut = self.path_sampler.plan(t, x0, x1)
        model_output = model(xt, t, **model_kwargs)
        B, *_, C = xt.shape
        assert model_output.size() == (B, *xt.size()[1:-1], C)

        terms = {}
        terms['pred'] = model_output
        if self.model_type == ModelType.VELOCITY:
            # Direct flow-matching regression onto the target velocity.
            terms['loss'] = mean_flat(((model_output - ut) ** 2))
        else:
            # Noise/score models: weighted regression, weight per loss_type.
            _, drift_var = self.path_sampler.compute_drift(xt, t)
            sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt))
            if self.loss_type in [WeightType.VELOCITY]:
                weight = (drift_var / sigma_t) ** 2
            elif self.loss_type in [WeightType.LIKELIHOOD]:
                weight = drift_var / (sigma_t ** 2)
            elif self.loss_type in [WeightType.NONE]:
                weight = 1
            else:
                raise NotImplementedError()

            if self.model_type == ModelType.NOISE:
                # Model predicts epsilon (= x0 here).
                terms['loss'] = mean_flat(weight * ((model_output - x0) ** 2))
            else:
                # Model predicts score; score * sigma_t = -x0 at the optimum.
                terms['loss'] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2))

        return terms


    def get_drift(
        self
    ):
        """member function for obtaining the drift of the probability flow ODE"""
        def score_ode(x, t, model, **model_kwargs):
            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
            model_output = model(x, t, **model_kwargs)
            return (-drift_mean + drift_var * model_output) # by change of variable

        def noise_ode(x, t, model, **model_kwargs):
            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
            sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
            model_output = model(x, t, **model_kwargs)
            # Convert predicted noise to a score: score = -noise / sigma_t.
            score = model_output / -sigma_t
            return (-drift_mean + drift_var * score)

        def velocity_ode(x, t, model, **model_kwargs):
            # Velocity models output the ODE drift directly.
            model_output = model(x, t, **model_kwargs)
            return model_output

        if self.model_type == ModelType.NOISE:
            drift_fn = noise_ode
        elif self.model_type == ModelType.SCORE:
            drift_fn = score_ode
        else:
            drift_fn = velocity_ode

        def body_fn(x, t, model, **model_kwargs):
            model_output = drift_fn(x, t, model, **model_kwargs)
            assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
            return model_output

        return body_fn


    def get_score(
        self,
    ):
        """member function for obtaining score of
            x_t = alpha_t * x + sigma_t * eps"""
        if self.model_type == ModelType.NOISE:
            score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs) / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
        elif self.model_type == ModelType.SCORE:
            # NOTE(review): 'kwagrs' is a typo, but it is used consistently
            # within the lambda, so behavior is unaffected.
            score_fn = lambda x, t, model, **kwagrs: model(x, t, **kwagrs)
        elif self.model_type == ModelType.VELOCITY:
            score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t)
        else:
            raise NotImplementedError()

        return score_fn
211
+
212
+
213
class Sampler:
    """Sampler class for the transport model"""
    def __init__(
        self,
        transport,
    ):
        """Constructor for a general sampler; supporting different sampling methods
        Args:
        - transport: an tranport object specify model prediction & interpolant type
        """

        self.transport = transport
        self.drift = self.transport.get_drift()  # probability-flow ODE drift
        self.score = self.transport.get_score()  # score function of x_t

    def __get_sde_diffusion_and_drift(
        self,
        *,
        diffusion_form="SBDM",
        diffusion_norm=1.0,
    ):
        """Build the SDE (drift, diffusion) pair from the ODE drift plus a
        score correction scaled by the chosen diffusion form."""

        def diffusion_fn(x, t):
            diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm)
            return diffusion

        # SDE drift = ODE drift + g(t) * score (reverse-time SDE form).
        sde_drift = \
            lambda x, t, model, **kwargs: \
                self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)

        sde_diffusion = diffusion_fn

        return sde_drift, sde_diffusion

    def __get_last_step(
        self,
        sde_drift,
        *,
        last_step,
        last_step_size,
    ):
        """Get the last step function of the SDE solver

        Args:
            sde_drift: SDE drift closure (used by the "Mean" variant).
            last_step: one of None / "Mean" / "Tweedie" / "Euler".
            last_step_size: step size for the final denoising step.
        """

        if last_step is None:
            # Identity: no final correction.
            last_step_fn = \
                lambda x, t, model, **model_kwargs: \
                    x
        elif last_step == "Mean":
            # One deterministic Euler step along the SDE drift (no noise).
            last_step_fn = \
                lambda x, t, model, **model_kwargs: \
                    x + sde_drift(x, t, model, **model_kwargs) * last_step_size
        elif last_step == "Tweedie":
            alpha = self.transport.path_sampler.compute_alpha_t # simple aliasing; the original name was too long
            sigma = self.transport.path_sampler.compute_sigma_t
            # Tweedie's formula: posterior mean given the current score.
            last_step_fn = \
                lambda x, t, model, **model_kwargs: \
                    x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][0] * self.score(x, t, model, **model_kwargs)
        elif last_step == "Euler":
            # One Euler step along the probability-flow ODE drift.
            last_step_fn = \
                lambda x, t, model, **model_kwargs: \
                    x + self.drift(x, t, model, **model_kwargs) * last_step_size
        else:
            raise NotImplementedError()

        return last_step_fn

    def sample_sde(
        self,
        *,
        sampling_method="Euler",
        diffusion_form="SBDM",
        diffusion_norm=1.0,
        last_step="Mean",
        last_step_size=0.04,
        num_steps=250,
    ):
        """returns a sampling function with given SDE settings
        Args:
        - sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama
        - diffusion_form: function form of diffusion coefficient; default to be matching SBDM
        - diffusion_norm: function magnitude of diffusion coefficient; default to 1
        - last_step: type of the last step; default to identity
        - last_step_size: size of the last step; default to match the stride of 250 steps over [0,1]
        - num_steps: total integration step of SDE
        """

        if last_step is None:
            last_step_size = 0.0

        sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
            diffusion_form=diffusion_form,
            diffusion_norm=diffusion_norm,
        )

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            diffusion_form=diffusion_form,
            sde=True,
            eval=True,
            reverse=False,
            last_step_size=last_step_size,
        )

        _sde = sde(
            sde_drift,
            sde_diffusion,
            t0=t0,
            t1=t1,
            num_steps=num_steps,
            sampler_type=sampling_method
        )

        last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)


        def _sample(init, model, **model_kwargs):
            # Integrate to t1, then apply the final denoising step at t1.
            xs = _sde.sample(init, model, **model_kwargs)
            ts = th.ones(init.size(0), device=init.device) * t1
            x = last_step_fn(xs[-1], ts, model, **model_kwargs)
            xs.append(x)

            # num_steps - 1 SDE steps plus one last step.
            assert len(xs) == num_steps, "Samples does not match the number of steps"

            return xs

        return _sample

    def sample_ode(
        self,
        *,
        sampling_method="dopri5",
        num_steps=50,
        atol=1e-6,
        rtol=1e-3,
        reverse=False,
    ):
        """returns a sampling function with given ODE settings
        Args:
        - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
        - num_steps:
            - fixed solver (Euler, Heun): the actual number of integration steps performed
            - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
        - atol: absolute error tolerance for the solver
        - rtol: relative error tolerance for the solver
        - reverse: whether solving the ODE in reverse (data to noise); default to False
        """
        drift = self.drift

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            sde=False,
            eval=True,
            reverse=reverse,
            last_step_size=0.0,
        )

        _ode = ode(
            drift=drift,
            t0=t0,
            t1=t1,
            sampler_type=sampling_method,
            num_steps=num_steps,
            atol=atol,
            rtol=rtol,
        )

        return _ode.sample

    def sample_ode_likelihood(
        self,
        *,
        sampling_method="dopri5",
        num_steps=50,
        atol=1e-6,
        rtol=1e-3,
    ):

        """returns a sampling function for calculating likelihood with given ODE settings
        Args:
        - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
        - num_steps:
            - fixed solver (Euler, Heun): the actual number of integration steps performed
            - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
        - atol: absolute error tolerance for the solver
        - rtol: relative error tolerance for the solver
        """
        def _likelihood_drift(x, t, model, **model_kwargs):
            # Hutchinson trace estimator with Rademacher probes; the state is
            # (x, accumulated log-det), integrated in reversed time 1 - t.
            x, _ = x
            eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
            t = th.ones_like(t) * (1 - t)
            with th.enable_grad():
                x.requires_grad = True
                grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
                logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
                drift = self.drift(x, t, model, **model_kwargs)
            return (-drift, logp_grad)

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            sde=False,
            eval=True,
            reverse=False,
            last_step_size=0.0,
        )

        _ode = ode(
            drift=_likelihood_drift,
            t0=t0,
            t1=t1,
            sampler_type=sampling_method,
            num_steps=num_steps,
            atol=atol,
            rtol=rtol,
        )

        def _sample_fn(x, model, **model_kwargs):
            # logp(x) = logp_prior(z) - integral of the divergence along the path.
            init_logp = th.zeros(x.size(0)).to(x)
            input = (x, init_logp)
            drift, delta_logp = _ode.sample(input, model, **model_kwargs)
            drift, delta_logp = drift[-1], delta_logp[-1]
            prior_logp = self.transport.prior_logp(drift)
            logp = prior_logp - delta_logp
            return logp, drift

        return _sample_fn
SiT_back/SiT_clean/transport/utils.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch as th
2
+
3
class EasyDict:
    """Lightweight object exposing a dict's entries both as attributes
    and via subscript (``obj["key"]``)."""

    def __init__(self, sub_dict):
        for key, value in sub_dict.items():
            setattr(self, key, value)

    def __getitem__(self, key):
        # Subscript access simply delegates to attribute lookup.
        return getattr(self, key)
11
+
12
def mean_flat(x):
    """
    Take the mean over all non-batch dimensions.
    """
    non_batch_dims = list(range(1, x.dim()))
    return th.mean(x, dim=non_batch_dims)
17
+
18
def log_state(state):
    """Render a state dict as one 'key: value' line per entry, sorted by key.

    Values whose str() looks like a bare object repr are shown as
    '[ClassName]' instead of the unhelpful '<... object at 0x...>' form.
    """
    lines = []
    for key in sorted(state):
        value = state[key]
        rendered = str(value)
        # Check if the value is an instance of a class
        if "<object" in rendered or "object at" in rendered:
            lines.append(f"{key}: [{value.__class__.__name__}]")
        else:
            lines.append(f"{key}: {value}")

    return '\n'.join(lines)
SiT_back/SiT_clean/wandb_utils.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import wandb
2
+ import torch
3
+ from torchvision.utils import make_grid
4
+ import torch.distributed as dist
5
+ from PIL import Image
6
+ import os
7
+ import argparse
8
+ import hashlib
9
+ import math
10
+
11
+
12
def is_main_process():
    """True only on distributed rank 0; used to gate wandb logging."""
    rank = dist.get_rank()
    return rank == 0
14
+
15
def namespace_to_dict(namespace):
    """Recursively convert an argparse.Namespace tree into nested dicts."""
    out = {}
    for key, value in vars(namespace).items():
        if isinstance(value, argparse.Namespace):
            value = namespace_to_dict(value)
        out[key] = value
    return out
20
+
21
+
22
def generate_run_id(exp_name):
    """Deterministically hash an experiment name into an at-most-8-digit run id.

    https://stackoverflow.com/questions/16008670/how-to-hash-a-string-into-8-digits
    """
    digest = hashlib.sha256(exp_name.encode('utf-8')).hexdigest()
    return str(int(digest, 16) % 10 ** 8)
25
+
26
+
27
def initialize(args, entity, exp_name, project_name):
    """Log in to wandb (API key from the WANDB_KEY env var) and start or
    resume the run identified deterministically by exp_name."""
    run_config = namespace_to_dict(args)
    wandb.login(key=os.environ["WANDB_KEY"])
    wandb.init(
        entity=entity,
        project=project_name,
        name=exp_name,
        config=run_config,
        id=generate_run_id(exp_name),
        resume="allow",
    )
38
+
39
+
40
def log(stats, step=None):
    """Log a stats mapping to wandb; no-op on non-zero ranks."""
    if not is_main_process():
        return
    wandb.log({key: value for key, value in stats.items()}, step=step)
43
+
44
+
45
def log_image(sample, step=None):
    """Log a batch of samples to wandb as a single image grid; rank-0 only."""
    if not is_main_process():
        return
    grid = array2grid(sample)
    wandb.log({f"samples": wandb.Image(grid), "train_step": step})
49
+
50
+
51
def array2grid(x):
    """Tile a batch of images (values in (-1, 1)) into a roughly square grid
    and return it as an HxWxC uint8 numpy array."""
    rows = round(math.sqrt(x.size(0)))
    grid = make_grid(x, nrow=rows, normalize=True, value_range=(-1, 1))
    # Scale to [0, 255], round, move channels last, and convert to numpy.
    grid = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    return grid