WayneW commited on
Commit
705a8fd
·
verified ·
1 Parent(s): f72be28

Upload folder using huggingface_hub

Browse files
config/data_config.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ action_stats:
2
+ min: [-2.5, -4] # [min_dx, min_dy]
3
+ max: [5, 4] # [max_dx, max_dy]
4
+
5
+ distance_diff_stats:
6
+ min: [-20] # [min]
7
+ max: [20] # [max]
8
+
9
+ avw_4k:
10
+ metric_waypoint_spacing: 0.15
config/eval_config.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ eval_distance:
2
+ eval_min_dist_cat: -16
3
+ eval_max_dist_cat: 16
4
+ eval_len_traj_pred: 16
5
+ eval_context_size: 4
6
+ traj_stride: 8
7
+
8
+ eval_datasets:
9
+ avw_4k:
10
+ data_folder: /path/to/dataset/avw_4k
11
+ test: /path/to/data_splits/avw_4k/test
12
+ goals_per_obs: 4
config/train_config_stage1.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ batch_size: 16
2
+ context_size: 4
3
+ datasets:
4
+ avw_4k:
5
+ data_folder: /path/to/dataset/avw_4k
6
+ goals_per_obs: 4
7
+ test: /path/to/data_splits/avw_4k/val
8
+ train: /path/to/data_splits/avw_4k/train
9
+ distance:
10
+ max_dist_cat: 16
11
+ min_dist_cat: -16
12
+ from_checkpoint: /path/to/pretrained/cdit_b_100000.pth.tar
13
+ grad_clip_val: 10.0
14
+ image_size: 224
15
+ len_traj_pred: 16
16
+ lr: 16.0e-05
17
+ model: AVCDiT-B/2
18
+ normalize: true
19
+ num_workers: 1
20
+ results_dir: logs
21
+ run_name: training_stage1
22
+ train: true
config/train_config_stage2.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ batch_size: 24
2
+ context_size: 4
3
+ datasets:
4
+ avw_4k:
5
+ data_folder: /path/to/dataset/avw_4k
6
+ goals_per_obs: 4
7
+ test: /path/to/data_splits/avw_4k/val
8
+ train: /path/to/data_splits/avw_4k/train
9
+ distance:
10
+ max_dist_cat: 16
11
+ min_dist_cat: -16
12
+ from_checkpoint: logs/training_stage1/checkpoints/latest.pth.tar
13
+ sample_rate: 16000
14
+ input_sr: 48000
15
+ tokenizer_a_path: /path/to/pretrained/soundstream.pt
16
+ grad_clip_val: 10.0
17
+ image_size: 224
18
+ len_traj_pred: 16
19
+ lr: 8.0e-4
20
+ model: AVCDiT-B/2
21
+ normalize: true
22
+ num_workers: 12
23
+ results_dir: logs
24
+ run_name: training_stage2
25
+ train: true
config/train_config_stage3.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ batch_size: 4
2
+ context_size: 4
3
+ datasets:
4
+ avw_4k:
5
+ data_folder: /path/to/dataset/avw_4k
6
+ goals_per_obs: 4
7
+ test: /path/to/data_splits/avw_4k/val
8
+ train: /path/to/data_splits/avw_4k/train
9
+ distance:
10
+ max_dist_cat: 16
11
+ min_dist_cat: -16
12
+ from_checkpoint: /path/to/pretrained/experts_merged.pth
13
+ sample_rate: 16000
14
+ input_sr: 48000
15
+ tokenizer_a_path: /path/to/pretrained/soundstream.pt
16
+ grad_clip_val: 10.0
17
+ image_size: 224
18
+ len_traj_pred: 16
19
+ lr: 16.0e-05
20
+ model: AVCDiT-B/2
21
+ normalize: true
22
+ num_workers: 12
23
+ results_dir: logs
24
+ run_name: training_stage3
25
+ train: true
datasets.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # NoMaD, GNM, ViNT: https://github.com/robodhruv/visualnav-transformer
9
+ # --------------------------------------------------------
10
+
11
+ import numpy as np
12
+ import torch
13
+ import os
14
+ from PIL import Image
15
+ from typing import Tuple
16
+ import yaml
17
+ import pickle
18
+ import tqdm
19
+ from torch.utils.data import Dataset
20
+ from misc import angle_difference, get_data_path, get_delta_np, normalize_data, to_local_coords
21
+ import torchaudio
22
+
23
+ class BaseDataset(Dataset):
24
+ def __init__(
25
+ self,
26
+ data_folder: str,
27
+ data_split_folder: str,
28
+ dataset_name: str,
29
+ image_size: Tuple[int, int],
30
+ min_dist_cat: int,
31
+ max_dist_cat: int,
32
+ len_traj_pred: int,
33
+ traj_stride: int,
34
+ context_size: int,
35
+ transform: object,
36
+ traj_names: str,
37
+ normalize: bool = True,
38
+ predefined_index: list = None,
39
+ goals_per_obs: int = 1,
40
+ ):
41
+ self.data_folder = data_folder
42
+ self.data_split_folder = data_split_folder
43
+ self.dataset_name = dataset_name
44
+ self.goals_per_obs = goals_per_obs
45
+
46
+
47
+ traj_names_file = os.path.join(data_split_folder, traj_names)
48
+ with open(traj_names_file, "r") as f:
49
+ file_lines = f.read()
50
+ self.traj_names = file_lines.split("\n")
51
+ if "" in self.traj_names:
52
+ self.traj_names.remove("")
53
+
54
+ self.image_size = image_size
55
+ self.distance_categories = list(range(min_dist_cat, max_dist_cat + 1))
56
+ self.min_dist_cat = self.distance_categories[0]
57
+ self.max_dist_cat = self.distance_categories[-1]
58
+ self.len_traj_pred = len_traj_pred
59
+ self.traj_stride = traj_stride
60
+
61
+ self.context_size = context_size
62
+ self.normalize = normalize
63
+
64
+ # load data/data_config.yaml
65
+ with open("config/data_config.yaml", "r") as f:
66
+ all_data_config = yaml.safe_load(f)
67
+
68
+ dataset_names = list(all_data_config.keys())
69
+ dataset_names.sort()
70
+ # use this index to retrieve the dataset name from the data_config.yaml
71
+ self.data_config = all_data_config[self.dataset_name]
72
+ self.transform = transform
73
+ self._load_index(predefined_index)
74
+ self.ACTION_STATS = {}
75
+ for key in all_data_config['action_stats']:
76
+ self.ACTION_STATS[key] = np.expand_dims(all_data_config['action_stats'][key], axis=0)
77
+ self.DISTANCE_DIFF_STATS = {} # [NEW]
78
+ for key in all_data_config['distance_diff_stats']: # [NEW]
79
+ self.DISTANCE_DIFF_STATS[key] = np.expand_dims(all_data_config['distance_diff_stats'][key], axis=0) # [NEW]
80
+
81
+ def _load_index(self, predefined_index) -> None:
82
+ """
83
+ Generates a list of tuples of (obs_traj_name, goal_traj_name, obs_time, goal_time) for each observation in the dataset
84
+ """
85
+ if predefined_index:
86
+ print(f"****** Using a predefined evaluation index... {predefined_index}******")
87
+ with open(predefined_index, "rb") as f:
88
+ self.index_to_data = pickle.load(f)
89
+ return
90
+ else:
91
+ print("****** Evaluating from NON PREDEFINED index... ******")
92
+ index_to_data_path = os.path.join(
93
+ self.data_split_folder,
94
+ f"dataset_dist_{self.min_dist_cat}_to_{self.max_dist_cat}_n{self.context_size}_len_traj_pred_{self.len_traj_pred}.pkl",
95
+ )
96
+
97
+ self.index_to_data, self.goals_index = self._build_index()
98
+ with open(index_to_data_path, "wb") as f:
99
+ pickle.dump((self.index_to_data, self.goals_index), f)
100
+
101
+ def _build_index(self, use_tqdm: bool = False):
102
+ """
103
+ Build an index consisting of tuples (trajectory name, time, max goal distance)
104
+ """
105
+ samples_index = []
106
+ goals_index = []
107
+
108
+ for traj_name in tqdm.tqdm(self.traj_names, disable=not use_tqdm, dynamic_ncols=True):
109
+ traj_data = self._get_trajectory(traj_name)
110
+ traj_len = len(traj_data["position"])
111
+ for goal_time in range(0, traj_len):
112
+ goals_index.append((traj_name, goal_time))
113
+
114
+ begin_time = self.context_size - 1
115
+ end_time = traj_len - self.len_traj_pred
116
+ for curr_time in range(begin_time, end_time, self.traj_stride):
117
+ max_goal_distance = min(self.max_dist_cat, traj_len - curr_time - 1)
118
+ min_goal_distance = max(self.min_dist_cat, -curr_time)
119
+ samples_index.append((traj_name, curr_time, min_goal_distance, max_goal_distance))
120
+
121
+ return samples_index, goals_index
122
+
123
+ def _get_trajectory(self, trajectory_name):
124
+ with open(os.path.join(self.data_folder, trajectory_name, "traj_data.pkl"), "rb") as f:
125
+ traj_data = pickle.load(f)
126
+ for k,v in traj_data.items():
127
+ traj_data[k] = v.astype('float')
128
+ return traj_data
129
+
130
+ def __len__(self) -> int:
131
+ return len(self.index_to_data)
132
+
133
+ def _compute_actions(self, traj_data, curr_time, goal_time):
134
+ start_index = curr_time
135
+ end_index = curr_time + self.len_traj_pred + 1
136
+ yaw = traj_data["yaw"][start_index:end_index]
137
+ positions = traj_data["position"][start_index:end_index]
138
+ goal_pos = traj_data["position"][goal_time]
139
+ goal_yaw = traj_data["yaw"][goal_time]
140
+ dist_window = traj_data["distance_to_target"][start_index:end_index] # shape (len_traj_pred+1,) # [NEW]
141
+ goal_dist = traj_data["distance_to_target"][goal_time] # shape (N,) or scalar # [NEW]
142
+
143
+ if len(yaw.shape) == 2:
144
+ yaw = yaw.squeeze(1)
145
+
146
+ if yaw.shape != (self.len_traj_pred + 1,):
147
+ raise ValueError("is used?")
148
+
149
+ waypoints_pos = to_local_coords(positions, positions[0], yaw[0])
150
+ waypoints_yaw = angle_difference(yaw[0], yaw)
151
+ actions = np.concatenate([waypoints_pos, waypoints_yaw.reshape(-1, 1)], axis=-1)
152
+ actions = actions[1:]
153
+
154
+ goal_pos = to_local_coords(goal_pos, positions[0], yaw[0])
155
+ goal_yaw = angle_difference(yaw[0], goal_yaw)
156
+
157
+ diffs_seq = (dist_window[0] - dist_window).reshape(-1, 1)[1:] # [NEW]
158
+ goal_diff = (dist_window[0] - goal_dist).reshape(-1, 1) # [NEW]
159
+
160
+ if self.normalize:
161
+ actions[:, :2] /= self.data_config["metric_waypoint_spacing"]
162
+ goal_pos[:, :2] /= self.data_config["metric_waypoint_spacing"]
163
+ diffs_seq /= self.data_config["metric_waypoint_spacing"] # [NEW]
164
+ goal_diff /= self.data_config["metric_waypoint_spacing"] # [NEW]
165
+
166
+ goal_pos = np.concatenate([goal_pos, goal_yaw.reshape(-1, 1)], axis=-1)
167
+ return actions, goal_pos, diffs_seq, goal_diff
168
+
169
+ class TrainingDataset(BaseDataset):
170
+ def __init__(
171
+ self,
172
+ data_folder: str,
173
+ data_split_folder: str,
174
+ dataset_name: str,
175
+ image_size: Tuple[int, int],
176
+ min_dist_cat: int,
177
+ max_dist_cat: int,
178
+ len_traj_pred: int,
179
+ traj_stride: int,
180
+ context_size: int,
181
+ transform: object,
182
+ traj_names: str = 'traj_names.txt',
183
+ normalize: bool = True,
184
+ predefined_index: list = None,
185
+ goals_per_obs: int = 1,
186
+ # sample_rate: int = 16000,
187
+ # target_len: int = 7840
188
+ sample_rate: int = 16000,
189
+ input_sr: int = 48000,
190
+ evaluate: bool = False
191
+ ):
192
+ super().__init__(data_folder, data_split_folder, dataset_name, image_size, min_dist_cat, max_dist_cat,
193
+ len_traj_pred, traj_stride, context_size, transform, traj_names, normalize, predefined_index, goals_per_obs)
194
+ self.resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=input_sr, lowpass_filter_width=64)
195
+ self.evaluate = evaluate
196
+
197
+ def __getitem__(self, i: int) -> Tuple[torch.Tensor]:
198
+ try:
199
+ f_curr, curr_time, min_goal_dist, max_goal_dist = self.index_to_data[i]
200
+ goal_offset = np.random.randint(min_goal_dist, max_goal_dist + 1, size=(self.goals_per_obs))
201
+ goal_time = (curr_time + goal_offset).astype('int')
202
+ rel_time = (goal_offset).astype('float')/(128.) # TODO: refactor, currently a fixed const
203
+
204
+ context_times = list(range(curr_time - self.context_size + 1, curr_time + 1))
205
+ context = [(f_curr, t) for t in context_times] + [(f_curr, t) for t in goal_time]
206
+
207
+ obs_image = torch.stack([self.transform(Image.open(get_data_path(self.data_folder, f, t))) for f, t in context])
208
+ obs_audio = torch.stack([torchaudio.load(get_data_path(self.data_folder, f, t, data_type="audio"))[0] for f, t in context])
209
+ if self.evaluate:
210
+ orig_obs_audio = obs_audio
211
+ obs_audio = self.resampler(obs_audio)
212
+
213
+ # Load other trajectory data
214
+ curr_traj_data = self._get_trajectory(f_curr)
215
+
216
+ # Compute actions
217
+ _, goal_pos, _, goal_diff = self._compute_actions(curr_traj_data, curr_time, goal_time)
218
+ goal_pos[:, :2] = normalize_data(goal_pos[:, :2], self.ACTION_STATS)
219
+ goal_diff = normalize_data(goal_diff, self.DISTANCE_DIFF_STATS)
220
+
221
+ if self.evaluate:
222
+ return (
223
+ torch.as_tensor(obs_image, dtype=torch.float32),
224
+ torch.as_tensor(obs_audio, dtype=torch.float32),
225
+ torch.as_tensor(goal_pos, dtype=torch.float32),
226
+ torch.as_tensor(goal_diff, dtype=torch.float32),
227
+ torch.as_tensor(rel_time, dtype=torch.float32),
228
+ torch.as_tensor(orig_obs_audio, dtype=torch.float32),
229
+ )
230
+ else:
231
+ return (
232
+ torch.as_tensor(obs_image, dtype=torch.float32),
233
+ torch.as_tensor(obs_audio, dtype=torch.float32),
234
+ torch.as_tensor(goal_pos, dtype=torch.float32),
235
+ torch.as_tensor(goal_diff, dtype=torch.float32),
236
+ torch.as_tensor(rel_time, dtype=torch.float32),
237
+ )
238
+ except Exception as e:
239
+ print(f"Exception in {self.dataset_name}", e)
240
+ raise Exception(e)
241
+
242
+ class EvalDataset(BaseDataset):
243
+ def __init__(
244
+ self,
245
+ data_folder: str,
246
+ data_split_folder: str,
247
+ dataset_name: str,
248
+ image_size: Tuple[int, int],
249
+ min_dist_cat: int,
250
+ max_dist_cat: int,
251
+ len_traj_pred: int,
252
+ traj_stride: int,
253
+ context_size: int,
254
+ transform: object,
255
+ traj_names: str,
256
+ normalize: bool = True,
257
+ predefined_index: list = None,
258
+ goals_per_obs: int = 1,
259
+ sample_rate: int = 16000,
260
+ input_sr: int = 48000
261
+ ):
262
+ super().__init__(data_folder, data_split_folder, dataset_name, image_size, min_dist_cat, max_dist_cat,
263
+ len_traj_pred, traj_stride, context_size, transform, traj_names, normalize, predefined_index, goals_per_obs)
264
+ self.resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=input_sr, lowpass_filter_width=64)
265
+
266
+ def __getitem__(self, i: int) -> Tuple[torch.Tensor]:
267
+ try:
268
+ f_curr, curr_time, _, _ = self.index_to_data[i]
269
+ context_times = list(range(curr_time - self.context_size + 1, curr_time + 1))
270
+ pred_times = list(range(curr_time + 1, curr_time + self.len_traj_pred + 1))
271
+
272
+ context = [(f_curr, t) for t in context_times]
273
+ pred = [(f_curr, t) for t in pred_times]
274
+
275
+ obs_image = torch.stack([self.transform(Image.open(get_data_path(self.data_folder, f, t))) for f, t in context])
276
+ pred_image = torch.stack([self.transform(Image.open(get_data_path(self.data_folder, f, t))) for f, t in pred])
277
+
278
+ orig_obs_audio = torch.stack([torchaudio.load(get_data_path(self.data_folder, f, t, data_type="audio"))[0] for f, t in context])
279
+ orig_pred_audio = torch.stack([torchaudio.load(get_data_path(self.data_folder, f, t, data_type="audio"))[0] for f, t in pred])
280
+
281
+ obs_audio = self.resampler(orig_obs_audio)
282
+ pred_audio = self.resampler(orig_pred_audio)
283
+
284
+ curr_traj_data = self._get_trajectory(f_curr)
285
+
286
+ # Compute actions
287
+ actions, _, diffs_seq, _ = self._compute_actions(curr_traj_data, curr_time, np.array([curr_time+1])) # last argument is dummy goal
288
+ actions[:, :2] = normalize_data(actions[:, :2], self.ACTION_STATS)
289
+ diffs_seq = normalize_data(diffs_seq, self.DISTANCE_DIFF_STATS)
290
+
291
+ delta = get_delta_np(actions)
292
+ diffs_seq = get_delta_np(diffs_seq)
293
+
294
+ return (
295
+ torch.tensor([i], dtype=torch.float32), # for logging purposes
296
+ torch.as_tensor(obs_image, dtype=torch.float32),
297
+ torch.as_tensor(pred_image, dtype=torch.float32),
298
+ torch.as_tensor(obs_audio, dtype=torch.float32),
299
+ torch.as_tensor(pred_audio, dtype=torch.float32),
300
+ torch.as_tensor(diffs_seq, dtype=torch.float32),
301
+ torch.as_tensor(delta, dtype=torch.float32),
302
+ torch.as_tensor(orig_obs_audio, dtype=torch.float32),
303
+ torch.as_tensor(orig_pred_audio, dtype=torch.float32),
304
+ )
305
+ except Exception as e:
306
+ print(f"Exception in {self.dataset_name}", e)
307
+ raise Exception(e)
diffusion/__init__.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import gaussian_diffusion as gd_orig
2
+ from . import gaussian_diffusion_dual as gd_dual
3
+ # from .respace import SpacedDiffusion, space_timesteps
4
+
5
+
6
+ def create_diffusion(
7
+ timestep_respacing,
8
+ noise_schedule="linear",
9
+ use_kl=False,
10
+ sigma_small=False,
11
+ predict_xstart=False,
12
+ learn_sigma=True,
13
+ rescale_learned_sigmas=False,
14
+ diffusion_steps=1000,
15
+ dual=False
16
+ ):
17
+ if dual:
18
+ print("Using DUAL diffusion")
19
+ from .respace_dual import SpacedDiffusion, space_timesteps
20
+ gd_module = gd_dual
21
+ else:
22
+ print("Using SINGLE diffusion")
23
+ from .respace import SpacedDiffusion, space_timesteps
24
+ gd_module = gd_orig
25
+
26
+ betas = gd_module.get_named_beta_schedule(noise_schedule, diffusion_steps)
27
+ # betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
28
+ if use_kl:
29
+ loss_type = gd_module.LossType.RESCALED_KL
30
+ elif rescale_learned_sigmas:
31
+ loss_type = gd_module.LossType.RESCALED_MSE
32
+ else:
33
+ loss_type = gd_module.LossType.MSE
34
+ if timestep_respacing is None or timestep_respacing == "":
35
+ timestep_respacing = [diffusion_steps]
36
+ return SpacedDiffusion(
37
+ use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
38
+ betas=betas,
39
+ model_mean_type=(
40
+ gd_module.ModelMeanType.EPSILON if not predict_xstart else gd_module.ModelMeanType.START_X
41
+ ),
42
+ model_var_type=(
43
+ (
44
+ gd_module.ModelVarType.FIXED_LARGE
45
+ if not sigma_small
46
+ else gd_module.ModelVarType.FIXED_SMALL
47
+ )
48
+ if not learn_sigma
49
+ else gd_module.ModelVarType.LEARNED_RANGE
50
+ ),
51
+ loss_type=loss_type
52
+ # rescale_timesteps=rescale_timesteps,
53
+ )
diffusion/diffusion_utils.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch as th
2
+ import numpy as np
3
+
4
+
5
+ def normal_kl(mean1, logvar1, mean2, logvar2):
6
+ """
7
+ Compute the KL divergence between two gaussians.
8
+ Shapes are automatically broadcasted, so batches can be compared to
9
+ scalars, among other use cases.
10
+ """
11
+ tensor = None
12
+ for obj in (mean1, logvar1, mean2, logvar2):
13
+ if isinstance(obj, th.Tensor):
14
+ tensor = obj
15
+ break
16
+ assert tensor is not None, "at least one argument must be a Tensor"
17
+
18
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
19
+ # Tensors, but it does not work for th.exp().
20
+ logvar1, logvar2 = [
21
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
22
+ for x in (logvar1, logvar2)
23
+ ]
24
+
25
+ return 0.5 * (
26
+ -1.0
27
+ + logvar2
28
+ - logvar1
29
+ + th.exp(logvar1 - logvar2)
30
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
31
+ )
32
+
33
+
34
+ def approx_standard_normal_cdf(x):
35
+ """
36
+ A fast approximation of the cumulative distribution function of the
37
+ standard normal.
38
+ """
39
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
40
+
41
+
42
+ def continuous_gaussian_log_likelihood(x, *, means, log_scales):
43
+ """
44
+ Compute the log-likelihood of a continuous Gaussian distribution.
45
+ :param x: the targets
46
+ :param means: the Gaussian mean Tensor.
47
+ :param log_scales: the Gaussian log stddev Tensor.
48
+ :return: a tensor like x of log probabilities (in nats).
49
+ """
50
+ centered_x = x - means
51
+ inv_stdv = th.exp(-log_scales)
52
+ normalized_x = centered_x * inv_stdv
53
+ log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
54
+ return log_probs
55
+
56
+
57
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
58
+ """
59
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
60
+ given image.
61
+ :param x: the target images. It is assumed that this was uint8 values,
62
+ rescaled to the range [-1, 1].
63
+ :param means: the Gaussian mean Tensor.
64
+ :param log_scales: the Gaussian log stddev Tensor.
65
+ :return: a tensor like x of log probabilities (in nats).
66
+ """
67
+ assert x.shape == means.shape == log_scales.shape
68
+ centered_x = x - means
69
+ inv_stdv = th.exp(-log_scales)
70
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
71
+ cdf_plus = approx_standard_normal_cdf(plus_in)
72
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
73
+ cdf_min = approx_standard_normal_cdf(min_in)
74
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
75
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
76
+ cdf_delta = cdf_plus - cdf_min
77
+ log_probs = th.where(
78
+ x < -0.999,
79
+ log_cdf_plus,
80
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
81
+ )
82
+ assert log_probs.shape == x.shape
83
+ return log_probs
diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,870 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch as th
5
+ import enum
6
+
7
+ from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
8
+ import torch.nn.functional as F
9
+ import torch
10
+
11
+
12
+ def mean_flat(tensor):
13
+ """
14
+ Take the mean over all non-batch dimensions.
15
+ """
16
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
17
+
18
+
19
+ class ModelMeanType(enum.Enum):
20
+ """
21
+ Which type of output the model predicts.
22
+ """
23
+
24
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
25
+ START_X = enum.auto() # the model predicts x_0
26
+ EPSILON = enum.auto() # the model predicts epsilon
27
+
28
+
29
+ class ModelVarType(enum.Enum):
30
+ """
31
+ What is used as the model's output variance.
32
+ The LEARNED_RANGE option has been added to allow the model to predict
33
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
34
+ """
35
+
36
+ LEARNED = enum.auto()
37
+ FIXED_SMALL = enum.auto()
38
+ FIXED_LARGE = enum.auto()
39
+ LEARNED_RANGE = enum.auto()
40
+
41
+
42
+ class LossType(enum.Enum):
43
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
44
+ RESCALED_MSE = (
45
+ enum.auto()
46
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
47
+ KL = enum.auto() # use the variational lower-bound
48
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
49
+
50
+ def is_vb(self):
51
+ return self == LossType.KL or self == LossType.RESCALED_KL
52
+
53
+
54
+ def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
55
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
56
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
57
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
58
+ return betas
59
+
60
+
61
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
62
+ """
63
+ This is the deprecated API for creating beta schedules.
64
+ See get_named_beta_schedule() for the new library of schedules.
65
+ """
66
+ if beta_schedule == "quad":
67
+ betas = (
68
+ np.linspace(
69
+ beta_start ** 0.5,
70
+ beta_end ** 0.5,
71
+ num_diffusion_timesteps,
72
+ dtype=np.float64,
73
+ )
74
+ ** 2
75
+ )
76
+ elif beta_schedule == "linear":
77
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
78
+ elif beta_schedule == "warmup10":
79
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
80
+ elif beta_schedule == "warmup50":
81
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
82
+ elif beta_schedule == "const":
83
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
84
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
85
+ betas = 1.0 / np.linspace(
86
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
87
+ )
88
+ else:
89
+ raise NotImplementedError(beta_schedule)
90
+ assert betas.shape == (num_diffusion_timesteps,)
91
+ return betas
92
+
93
+
94
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
95
+ """
96
+ Get a pre-defined beta schedule for the given name.
97
+ The beta schedule library consists of beta schedules which remain similar
98
+ in the limit of num_diffusion_timesteps.
99
+ Beta schedules may be added, but should not be removed or changed once
100
+ they are committed to maintain backwards compatibility.
101
+ """
102
+ if schedule_name == "linear":
103
+ # Linear schedule from Ho et al, extended to work for any number of
104
+ # diffusion steps.
105
+ scale = 1000 / num_diffusion_timesteps
106
+ return get_beta_schedule(
107
+ "linear",
108
+ beta_start=scale * 0.0001,
109
+ beta_end=scale * 0.02,
110
+ num_diffusion_timesteps=num_diffusion_timesteps,
111
+ )
112
+ elif schedule_name == "squaredcos_cap_v2":
113
+ return betas_for_alpha_bar(
114
+ num_diffusion_timesteps,
115
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
116
+ )
117
+ else:
118
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
119
+
120
+
121
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
122
+ """
123
+ Create a beta schedule that discretizes the given alpha_t_bar function,
124
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
125
+ :param num_diffusion_timesteps: the number of betas to produce.
126
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
127
+ produces the cumulative product of (1-beta) up to that
128
+ part of the diffusion process.
129
+ :param max_beta: the maximum beta to use; use values lower than 1 to
130
+ prevent singularities.
131
+ """
132
+ betas = []
133
+ for i in range(num_diffusion_timesteps):
134
+ t1 = i / num_diffusion_timesteps
135
+ t2 = (i + 1) / num_diffusion_timesteps
136
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
137
+ return np.array(betas)
138
+
139
+
140
+ class GaussianDiffusion:
141
+ """
142
+ Utilities for training and sampling diffusion models.
143
+ Original ported from this codebase:
144
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
145
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
146
+ starting at T and going to 1.
147
+ """
148
+
149
+ def __init__(
150
+ self,
151
+ *,
152
+ betas,
153
+ model_mean_type,
154
+ model_var_type,
155
+ loss_type
156
+ ):
157
+
158
+ self.model_mean_type = model_mean_type
159
+ self.model_var_type = model_var_type
160
+ self.loss_type = loss_type
161
+
162
+ # Use float64 for accuracy.
163
+ betas = np.array(betas, dtype=np.float64)
164
+ self.betas = betas
165
+ assert len(betas.shape) == 1, "betas must be 1-D"
166
+ assert (betas > 0).all() and (betas <= 1).all()
167
+
168
+ self.num_timesteps = int(betas.shape[0])
169
+
170
+ alphas = 1.0 - betas
171
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
172
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
173
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
174
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
175
+
176
+ # calculations for diffusion q(x_t | x_{t-1}) and others
177
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
178
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
179
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
180
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
181
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
182
+
183
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
184
+ self.posterior_variance = (
185
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
186
+ )
187
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
188
+ self.posterior_log_variance_clipped = np.log(
189
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
190
+ ) if len(self.posterior_variance) > 1 else np.array([])
191
+
192
+ self.posterior_mean_coef1 = (
193
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
194
+ )
195
+ self.posterior_mean_coef2 = (
196
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
197
+ )
198
+
199
+ def q_mean_variance(self, x_start, t):
200
+ """
201
+ Get the distribution q(x_t | x_0).
202
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
203
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
204
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
205
+ """
206
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
207
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
208
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
209
+ return mean, variance, log_variance
210
+
211
+ def q_sample(self, x_start, t, noise=None):
212
+ """
213
+ Diffuse the data for a given number of diffusion steps.
214
+ In other words, sample from q(x_t | x_0).
215
+ :param x_start: the initial data batch.
216
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
217
+ :param noise: if specified, the split-out normal noise.
218
+ :return: A noisy version of x_start.
219
+ """
220
+ if noise is None:
221
+ noise = th.randn_like(x_start)
222
+ assert noise.shape == x_start.shape
223
+ return (
224
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
225
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
226
+ )
227
+
228
+ def q_posterior_mean_variance(self, x_start, x_t, t):
229
+ """
230
+ Compute the mean and variance of the diffusion posterior:
231
+ q(x_{t-1} | x_t, x_0)
232
+ """
233
+ assert x_start.shape == x_t.shape
234
+ posterior_mean = (
235
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
236
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
237
+ )
238
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
239
+ posterior_log_variance_clipped = _extract_into_tensor(
240
+ self.posterior_log_variance_clipped, t, x_t.shape
241
+ )
242
+ assert (
243
+ posterior_mean.shape[0]
244
+ == posterior_variance.shape[0]
245
+ == posterior_log_variance_clipped.shape[0]
246
+ == x_start.shape[0]
247
+ )
248
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
249
+
250
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
251
+ """
252
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
253
+ the initial x, x_0.
254
+ :param model: the model, which takes a signal and a batch of timesteps
255
+ as input.
256
+ :param x: the [N x C x ...] tensor at time t.
257
+ :param t: a 1-D Tensor of timesteps.
258
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
259
+ :param denoised_fn: if not None, a function which applies to the
260
+ x_start prediction before it is used to sample. Applies before
261
+ clip_denoised.
262
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
263
+ pass to the model. This can be used for conditioning.
264
+ :return: a dict with the following keys:
265
+ - 'mean': the model mean output.
266
+ - 'variance': the model variance output.
267
+ - 'log_variance': the log of 'variance'.
268
+ - 'pred_xstart': the prediction for x_0.
269
+ """
270
+ if model_kwargs is None:
271
+ model_kwargs = {}
272
+
273
+ B, C = x.shape[:2]
274
+ assert t.shape == (B,)
275
+ model_output = model(x, t, **model_kwargs)
276
+ if isinstance(model_output, tuple):
277
+ model_output, extra = model_output
278
+ else:
279
+ extra = None
280
+
281
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
282
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
283
+ model_output, model_var_values = th.split(model_output, C, dim=1)
284
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
285
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
286
+ # The model_var_values is [-1, 1] for [min_var, max_var].
287
+ frac = (model_var_values + 1) / 2
288
+ model_log_variance = frac * max_log + (1 - frac) * min_log
289
+ model_variance = th.exp(model_log_variance)
290
+ else:
291
+ model_variance, model_log_variance = {
292
+ # for fixedlarge, we set the initial (log-)variance like so
293
+ # to get a better decoder log likelihood.
294
+ ModelVarType.FIXED_LARGE: (
295
+ np.append(self.posterior_variance[1], self.betas[1:]),
296
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
297
+ ),
298
+ ModelVarType.FIXED_SMALL: (
299
+ self.posterior_variance,
300
+ self.posterior_log_variance_clipped,
301
+ ),
302
+ }[self.model_var_type]
303
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
304
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
305
+
306
+ def process_xstart(x):
307
+ if denoised_fn is not None:
308
+ x = denoised_fn(x)
309
+ if clip_denoised:
310
+ return x.clamp(-1, 1)
311
+ return x
312
+
313
+ if self.model_mean_type == ModelMeanType.START_X:
314
+ pred_xstart = process_xstart(model_output)
315
+ else:
316
+ pred_xstart = process_xstart(
317
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
318
+ )
319
+ model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
320
+
321
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
322
+ return {
323
+ "mean": model_mean,
324
+ "variance": model_variance,
325
+ "log_variance": model_log_variance,
326
+ "pred_xstart": pred_xstart,
327
+ "extra": extra,
328
+ }
329
+
330
+ def _predict_xstart_from_eps(self, x_t, t, eps):
331
+ assert x_t.shape == eps.shape
332
+ return (
333
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
334
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
335
+ )
336
+
337
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
338
+ return (
339
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
340
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
341
+
342
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
343
+ """
344
+ Compute the mean for the previous step, given a function cond_fn that
345
+ computes the gradient of a conditional log probability with respect to
346
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
347
+ condition on y.
348
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
349
+ """
350
+ gradient = cond_fn(x, t, **model_kwargs)
351
+ new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
352
+ return new_mean
353
+
354
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
355
+ """
356
+ Compute what the p_mean_variance output would have been, should the
357
+ model's score function be conditioned by cond_fn.
358
+ See condition_mean() for details on cond_fn.
359
+ Unlike condition_mean(), this instead uses the conditioning strategy
360
+ from Song et al (2020).
361
+ """
362
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
363
+
364
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
365
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
366
+
367
+ out = p_mean_var.copy()
368
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
369
+ out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
370
+ return out
371
+
372
+ def p_sample(
373
+ self,
374
+ model,
375
+ x,
376
+ t,
377
+ clip_denoised=True,
378
+ denoised_fn=None,
379
+ cond_fn=None,
380
+ model_kwargs=None,
381
+ ):
382
+ """
383
+ Sample x_{t-1} from the model at the given timestep.
384
+ :param model: the model to sample from.
385
+ :param x: the current tensor at x_{t-1}.
386
+ :param t: the value of t, starting at 0 for the first diffusion step.
387
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
388
+ :param denoised_fn: if not None, a function which applies to the
389
+ x_start prediction before it is used to sample.
390
+ :param cond_fn: if not None, this is a gradient function that acts
391
+ similarly to the model.
392
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
393
+ pass to the model. This can be used for conditioning.
394
+ :return: a dict containing the following keys:
395
+ - 'sample': a random sample from the model.
396
+ - 'pred_xstart': a prediction of x_0.
397
+ """
398
+ out = self.p_mean_variance(
399
+ model,
400
+ x,
401
+ t,
402
+ clip_denoised=clip_denoised,
403
+ denoised_fn=denoised_fn,
404
+ model_kwargs=model_kwargs,
405
+ )
406
+ noise = th.randn_like(x)
407
+ nonzero_mask = (
408
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
409
+ ) # no noise when t == 0
410
+ if cond_fn is not None:
411
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
412
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
413
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
414
+
415
+ def p_sample_loop(
416
+ self,
417
+ model,
418
+ shape,
419
+ noise=None,
420
+ clip_denoised=True,
421
+ denoised_fn=None,
422
+ cond_fn=None,
423
+ model_kwargs=None,
424
+ device=None,
425
+ progress=False,
426
+ ):
427
+ """
428
+ Generate samples from the model.
429
+ :param model: the model module.
430
+ :param shape: the shape of the samples, (N, C, H, W).
431
+ :param noise: if specified, the noise from the encoder to sample.
432
+ Should be of the same shape as `shape`.
433
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
434
+ :param denoised_fn: if not None, a function which applies to the
435
+ x_start prediction before it is used to sample.
436
+ :param cond_fn: if not None, this is a gradient function that acts
437
+ similarly to the model.
438
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
439
+ pass to the model. This can be used for conditioning.
440
+ :param device: if specified, the device to create the samples on.
441
+ If not specified, use a model parameter's device.
442
+ :param progress: if True, show a tqdm progress bar.
443
+ :return: a non-differentiable batch of samples.
444
+ """
445
+ final = None
446
+ for sample in self.p_sample_loop_progressive(
447
+ model,
448
+ shape,
449
+ noise=noise,
450
+ clip_denoised=clip_denoised,
451
+ denoised_fn=denoised_fn,
452
+ cond_fn=cond_fn,
453
+ model_kwargs=model_kwargs,
454
+ device=device,
455
+ progress=progress,
456
+ ):
457
+ final = sample
458
+ return final["sample"]
459
+
460
+ def p_sample_loop_progressive(
461
+ self,
462
+ model,
463
+ shape,
464
+ noise=None,
465
+ clip_denoised=True,
466
+ denoised_fn=None,
467
+ cond_fn=None,
468
+ model_kwargs=None,
469
+ device=None,
470
+ progress=False,
471
+ ):
472
+ """
473
+ Generate samples from the model and yield intermediate samples from
474
+ each timestep of diffusion.
475
+ Arguments are the same as p_sample_loop().
476
+ Returns a generator over dicts, where each dict is the return value of
477
+ p_sample().
478
+ """
479
+ if device is None:
480
+ device = next(model.parameters()).device
481
+ assert isinstance(shape, (tuple, list))
482
+ if noise is not None:
483
+ img = noise
484
+ else:
485
+ img = th.randn(*shape, device=device)
486
+ indices = list(range(self.num_timesteps))[::-1]
487
+
488
+ if progress:
489
+ # Lazy import so that we don't depend on tqdm.
490
+ from tqdm.auto import tqdm
491
+
492
+ indices = tqdm(indices)
493
+
494
+ for i in indices:
495
+ t = th.tensor([i] * shape[0], device=device)
496
+ with th.no_grad():
497
+ out = self.p_sample(
498
+ model,
499
+ img,
500
+ t,
501
+ clip_denoised=clip_denoised,
502
+ denoised_fn=denoised_fn,
503
+ cond_fn=cond_fn,
504
+ model_kwargs=model_kwargs,
505
+ )
506
+ yield out
507
+ img = out["sample"]
508
+
509
+ def ddim_sample(
510
+ self,
511
+ model,
512
+ x,
513
+ t,
514
+ clip_denoised=True,
515
+ denoised_fn=None,
516
+ cond_fn=None,
517
+ model_kwargs=None,
518
+ eta=0.0,
519
+ ):
520
+ """
521
+ Sample x_{t-1} from the model using DDIM.
522
+ Same usage as p_sample().
523
+ """
524
+ out = self.p_mean_variance(
525
+ model,
526
+ x,
527
+ t,
528
+ clip_denoised=clip_denoised,
529
+ denoised_fn=denoised_fn,
530
+ model_kwargs=model_kwargs,
531
+ )
532
+ if cond_fn is not None:
533
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
534
+
535
+ # Usually our model outputs epsilon, but we re-derive it
536
+ # in case we used x_start or x_prev prediction.
537
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
538
+
539
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
540
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
541
+ sigma = (
542
+ eta
543
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
544
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
545
+ )
546
+ # Equation 12.
547
+ noise = th.randn_like(x)
548
+ mean_pred = (
549
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
550
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
551
+ )
552
+ nonzero_mask = (
553
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
554
+ ) # no noise when t == 0
555
+ sample = mean_pred + nonzero_mask * sigma * noise
556
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
557
+
558
+ def ddim_reverse_sample(
559
+ self,
560
+ model,
561
+ x,
562
+ t,
563
+ clip_denoised=True,
564
+ denoised_fn=None,
565
+ cond_fn=None,
566
+ model_kwargs=None,
567
+ eta=0.0,
568
+ ):
569
+ """
570
+ Sample x_{t+1} from the model using DDIM reverse ODE.
571
+ """
572
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
573
+ out = self.p_mean_variance(
574
+ model,
575
+ x,
576
+ t,
577
+ clip_denoised=clip_denoised,
578
+ denoised_fn=denoised_fn,
579
+ model_kwargs=model_kwargs,
580
+ )
581
+ if cond_fn is not None:
582
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
583
+ # Usually our model outputs epsilon, but we re-derive it
584
+ # in case we used x_start or x_prev prediction.
585
+ eps = (
586
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
587
+ - out["pred_xstart"]
588
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
589
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
590
+
591
+ # Equation 12. reversed
592
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
593
+
594
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
595
+
596
+ def ddim_sample_loop(
597
+ self,
598
+ model,
599
+ shape,
600
+ noise=None,
601
+ clip_denoised=True,
602
+ denoised_fn=None,
603
+ cond_fn=None,
604
+ model_kwargs=None,
605
+ device=None,
606
+ progress=False,
607
+ eta=0.0,
608
+ ):
609
+ """
610
+ Generate samples from the model using DDIM.
611
+ Same usage as p_sample_loop().
612
+ """
613
+ final = None
614
+ for sample in self.ddim_sample_loop_progressive(
615
+ model,
616
+ shape,
617
+ noise=noise,
618
+ clip_denoised=clip_denoised,
619
+ denoised_fn=denoised_fn,
620
+ cond_fn=cond_fn,
621
+ model_kwargs=model_kwargs,
622
+ device=device,
623
+ progress=progress,
624
+ eta=eta,
625
+ ):
626
+ final = sample
627
+ return final["sample"]
628
+
629
+ def ddim_sample_loop_progressive(
630
+ self,
631
+ model,
632
+ shape,
633
+ noise=None,
634
+ clip_denoised=True,
635
+ denoised_fn=None,
636
+ cond_fn=None,
637
+ model_kwargs=None,
638
+ device=None,
639
+ progress=False,
640
+ eta=0.0,
641
+ ):
642
+ """
643
+ Use DDIM to sample from the model and yield intermediate samples from
644
+ each timestep of DDIM.
645
+ Same usage as p_sample_loop_progressive().
646
+ """
647
+ if device is None:
648
+ device = next(model.parameters()).device
649
+ assert isinstance(shape, (tuple, list))
650
+ if noise is not None:
651
+ img = noise
652
+ else:
653
+ img = th.randn(*shape, device=device)
654
+ indices = list(range(self.num_timesteps))[::-1]
655
+
656
+ if progress:
657
+ # Lazy import so that we don't depend on tqdm.
658
+ from tqdm.auto import tqdm
659
+
660
+ indices = tqdm(indices)
661
+
662
+ for i in indices:
663
+ t = th.tensor([i] * shape[0], device=device)
664
+ with th.no_grad():
665
+ out = self.ddim_sample(
666
+ model,
667
+ img,
668
+ t,
669
+ clip_denoised=clip_denoised,
670
+ denoised_fn=denoised_fn,
671
+ cond_fn=cond_fn,
672
+ model_kwargs=model_kwargs,
673
+ eta=eta,
674
+ )
675
+ yield out
676
+ img = out["sample"]
677
+
678
+ def _vb_terms_bpd(
679
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
680
+ ):
681
+ """
682
+ Get a term for the variational lower-bound.
683
+ The resulting units are bits (rather than nats, as one might expect).
684
+ This allows for comparison to other papers.
685
+ :return: a dict with the following keys:
686
+ - 'output': a shape [N] tensor of NLLs or KLs.
687
+ - 'pred_xstart': the x_0 predictions.
688
+ """
689
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
690
+ x_start=x_start, x_t=x_t, t=t
691
+ )
692
+ out = self.p_mean_variance(
693
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
694
+ )
695
+ kl = normal_kl(
696
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
697
+ )
698
+ kl = mean_flat(kl) / np.log(2.0)
699
+
700
+ decoder_nll = -discretized_gaussian_log_likelihood(
701
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
702
+ )
703
+ assert decoder_nll.shape == x_start.shape
704
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
705
+
706
+ # At the first timestep return the decoder NLL,
707
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
708
+ output = th.where((t == 0), decoder_nll, kl)
709
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
710
+
711
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
712
+ """
713
+ Compute training losses for a single timestep.
714
+ :param model: the model to evaluate loss on.
715
+ :param x_start: the [N x C x ...] tensor of inputs.
716
+ :param t: a batch of timestep indices.
717
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
718
+ pass to the model. This can be used for conditioning.
719
+ :param noise: if specified, the specific Gaussian noise to try to remove.
720
+ :return: a dict with the key "loss" containing a tensor of shape [N].
721
+ Some mean or variance settings may also have other keys.
722
+ """
723
+ if model_kwargs is None:
724
+ model_kwargs = {}
725
+ if noise is None:
726
+ noise = th.randn_like(x_start)
727
+ x_t = self.q_sample(x_start, t, noise=noise)
728
+
729
+ terms = {}
730
+
731
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
732
+ terms["loss"] = self._vb_terms_bpd(
733
+ model=model,
734
+ x_start=x_start,
735
+ x_t=x_t,
736
+ t=t,
737
+ clip_denoised=False,
738
+ model_kwargs=model_kwargs,
739
+ )["output"]
740
+ if self.loss_type == LossType.RESCALED_KL:
741
+ terms["loss"] *= self.num_timesteps
742
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
743
+ model_output = model(x_t, t, **model_kwargs)
744
+
745
+ if self.model_var_type in [
746
+ ModelVarType.LEARNED,
747
+ ModelVarType.LEARNED_RANGE,
748
+ ]:
749
+ B, C = x_t.shape[:2]
750
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
751
+ model_output, model_var_values = th.split(model_output, C, dim=1)
752
+ # Learn the variance using the variational bound, but don't let
753
+ # it affect our mean prediction.
754
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
755
+ terms["vb"] = self._vb_terms_bpd(
756
+ model=lambda *args, r=frozen_out: r,
757
+ x_start=x_start,
758
+ x_t=x_t,
759
+ t=t,
760
+ clip_denoised=False,
761
+ )["output"]
762
+ if self.loss_type == LossType.RESCALED_MSE:
763
+ # Divide by 1000 for equivalence with initial implementation.
764
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
765
+ terms["vb"] *= self.num_timesteps / 1000.0
766
+
767
+ target = {
768
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
769
+ x_start=x_start, x_t=x_t, t=t
770
+ )[0],
771
+ ModelMeanType.START_X: x_start,
772
+ ModelMeanType.EPSILON: noise,
773
+ }[self.model_mean_type]
774
+ assert model_output.shape == target.shape == x_start.shape
775
+
776
+ terms["mse"] = mean_flat((target - model_output) ** 2)
777
+ if "vb" in terms:
778
+ terms["loss"] = terms["mse"] + terms["vb"]
779
+ else:
780
+ terms["loss"] = terms["mse"]
781
+ else:
782
+ raise NotImplementedError(self.loss_type)
783
+
784
+ return terms
785
+
786
+ def _prior_bpd(self, x_start):
787
+ """
788
+ Get the prior KL term for the variational lower-bound, measured in
789
+ bits-per-dim.
790
+ This term can't be optimized, as it only depends on the encoder.
791
+ :param x_start: the [N x C x ...] tensor of inputs.
792
+ :return: a batch of [N] KL values (in bits), one per batch element.
793
+ """
794
+ batch_size = x_start.shape[0]
795
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
796
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
797
+ kl_prior = normal_kl(
798
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
799
+ )
800
+ return mean_flat(kl_prior) / np.log(2.0)
801
+
802
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
803
+ """
804
+ Compute the entire variational lower-bound, measured in bits-per-dim,
805
+ as well as other related quantities.
806
+ :param model: the model to evaluate loss on.
807
+ :param x_start: the [N x C x ...] tensor of inputs.
808
+ :param clip_denoised: if True, clip denoised samples.
809
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
810
+ pass to the model. This can be used for conditioning.
811
+ :return: a dict containing the following keys:
812
+ - total_bpd: the total variational lower-bound, per batch element.
813
+ - prior_bpd: the prior term in the lower-bound.
814
+ - vb: an [N x T] tensor of terms in the lower-bound.
815
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
816
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
817
+ """
818
+ device = x_start.device
819
+ batch_size = x_start.shape[0]
820
+
821
+ vb = []
822
+ xstart_mse = []
823
+ mse = []
824
+ for t in list(range(self.num_timesteps))[::-1]:
825
+ t_batch = th.tensor([t] * batch_size, device=device)
826
+ noise = th.randn_like(x_start)
827
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
828
+ # Calculate VLB term at the current timestep
829
+ with th.no_grad():
830
+ out = self._vb_terms_bpd(
831
+ model,
832
+ x_start=x_start,
833
+ x_t=x_t,
834
+ t=t_batch,
835
+ clip_denoised=clip_denoised,
836
+ model_kwargs=model_kwargs,
837
+ )
838
+ vb.append(out["output"])
839
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
840
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
841
+ mse.append(mean_flat((eps - noise) ** 2))
842
+
843
+ vb = th.stack(vb, dim=1)
844
+ xstart_mse = th.stack(xstart_mse, dim=1)
845
+ mse = th.stack(mse, dim=1)
846
+
847
+ prior_bpd = self._prior_bpd(x_start)
848
+ total_bpd = vb.sum(dim=1) + prior_bpd
849
+ return {
850
+ "total_bpd": total_bpd,
851
+ "prior_bpd": prior_bpd,
852
+ "vb": vb,
853
+ "xstart_mse": xstart_mse,
854
+ "mse": mse,
855
+ }
856
+
857
+
858
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
859
+ """
860
+ Extract values from a 1-D numpy array for a batch of indices.
861
+ :param arr: the 1-D numpy array.
862
+ :param timesteps: a tensor of indices into the array to extract.
863
+ :param broadcast_shape: a larger shape of K dimensions with the batch
864
+ dimension equal to the length of timesteps.
865
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
866
+ """
867
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
868
+ while len(res.shape) < len(broadcast_shape):
869
+ res = res[..., None]
870
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
diffusion/gaussian_diffusion_dual.py ADDED
@@ -0,0 +1,975 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # Modified from OpenAI's diffusion repos
8
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
9
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
10
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
11
+
12
+
13
+ import math
14
+
15
+ import numpy as np
16
+ import torch as th
17
+ import enum
18
+
19
+ from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
20
+
21
+
22
+ def mean_flat(tensor):
23
+ """
24
+ Take the mean over all non-batch dimensions.
25
+ """
26
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
27
+
28
+
29
+ class ModelMeanType(enum.Enum):
30
+ """
31
+ Which type of output the model predicts.
32
+ """
33
+
34
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
35
+ START_X = enum.auto() # the model predicts x_0
36
+ EPSILON = enum.auto() # the model predicts epsilon
37
+
38
+
39
+ class ModelVarType(enum.Enum):
40
+ """
41
+ What is used as the model's output variance.
42
+ The LEARNED_RANGE option has been added to allow the model to predict
43
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
44
+ """
45
+
46
+ LEARNED = enum.auto()
47
+ FIXED_SMALL = enum.auto()
48
+ FIXED_LARGE = enum.auto()
49
+ LEARNED_RANGE = enum.auto()
50
+
51
+
52
+ class LossType(enum.Enum):
53
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
54
+ RESCALED_MSE = (
55
+ enum.auto()
56
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
57
+ KL = enum.auto() # use the variational lower-bound
58
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
59
+
60
+ def is_vb(self):
61
+ return self == LossType.KL or self == LossType.RESCALED_KL
62
+
63
+
64
+ def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
65
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
66
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
67
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
68
+ return betas
69
+
70
+
71
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
72
+ """
73
+ This is the deprecated API for creating beta schedules.
74
+ See get_named_beta_schedule() for the new library of schedules.
75
+ """
76
+ if beta_schedule == "quad":
77
+ betas = (
78
+ np.linspace(
79
+ beta_start ** 0.5,
80
+ beta_end ** 0.5,
81
+ num_diffusion_timesteps,
82
+ dtype=np.float64,
83
+ )
84
+ ** 2
85
+ )
86
+ elif beta_schedule == "linear":
87
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
88
+ elif beta_schedule == "warmup10":
89
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
90
+ elif beta_schedule == "warmup50":
91
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
92
+ elif beta_schedule == "const":
93
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
94
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
95
+ betas = 1.0 / np.linspace(
96
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
97
+ )
98
+ else:
99
+ raise NotImplementedError(beta_schedule)
100
+ assert betas.shape == (num_diffusion_timesteps,)
101
+ return betas
102
+
103
+
104
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
105
+ """
106
+ Get a pre-defined beta schedule for the given name.
107
+ The beta schedule library consists of beta schedules which remain similar
108
+ in the limit of num_diffusion_timesteps.
109
+ Beta schedules may be added, but should not be removed or changed once
110
+ they are committed to maintain backwards compatibility.
111
+ """
112
+ if schedule_name == "linear":
113
+ # Linear schedule from Ho et al, extended to work for any number of
114
+ # diffusion steps.
115
+ scale = 1000 / num_diffusion_timesteps
116
+ return get_beta_schedule(
117
+ "linear",
118
+ beta_start=scale * 0.0001,
119
+ beta_end=scale * 0.02,
120
+ num_diffusion_timesteps=num_diffusion_timesteps,
121
+ )
122
+ elif schedule_name == "squaredcos_cap_v2":
123
+ return betas_for_alpha_bar(
124
+ num_diffusion_timesteps,
125
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
126
+ )
127
+ else:
128
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
129
+
130
+
131
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
132
+ """
133
+ Create a beta schedule that discretizes the given alpha_t_bar function,
134
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
135
+ :param num_diffusion_timesteps: the number of betas to produce.
136
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
137
+ produces the cumulative product of (1-beta) up to that
138
+ part of the diffusion process.
139
+ :param max_beta: the maximum beta to use; use values lower than 1 to
140
+ prevent singularities.
141
+ """
142
+ betas = []
143
+ for i in range(num_diffusion_timesteps):
144
+ t1 = i / num_diffusion_timesteps
145
+ t2 = (i + 1) / num_diffusion_timesteps
146
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
147
+ return np.array(betas)
148
+
149
+
150
+ class GaussianDiffusion:
151
+ """
152
+ Utilities for training and sampling diffusion models.
153
+ Original ported from this codebase:
154
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
155
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
156
+ starting at T and going to 1.
157
+ """
158
+
159
+ def __init__(
160
+ self,
161
+ *,
162
+ betas,
163
+ model_mean_type,
164
+ model_var_type,
165
+ loss_type
166
+ ):
167
+
168
+ self.model_mean_type = model_mean_type
169
+ self.model_var_type = model_var_type
170
+ self.loss_type = loss_type
171
+
172
+ # Use float64 for accuracy.
173
+ betas = np.array(betas, dtype=np.float64)
174
+ self.betas = betas
175
+ assert len(betas.shape) == 1, "betas must be 1-D"
176
+ assert (betas > 0).all() and (betas <= 1).all()
177
+
178
+ self.num_timesteps = int(betas.shape[0])
179
+
180
+ alphas = 1.0 - betas
181
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
182
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
183
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
184
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
185
+
186
+ # calculations for diffusion q(x_t | x_{t-1}) and others
187
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
188
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
189
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
190
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
191
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
192
+
193
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
194
+ self.posterior_variance = (
195
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
196
+ )
197
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
198
+ self.posterior_log_variance_clipped = np.log(
199
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
200
+ ) if len(self.posterior_variance) > 1 else np.array([])
201
+
202
+ self.posterior_mean_coef1 = (
203
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
204
+ )
205
+ self.posterior_mean_coef2 = (
206
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
207
+ )
208
+
209
+ def q_mean_variance(self, x_start, t):
210
+ """
211
+ Get the distribution q(x_t | x_0).
212
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
213
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
214
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
215
+ """
216
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
217
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
218
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
219
+ return mean, variance, log_variance
220
+
221
+ def q_sample(self, x_start, t, noise=None):
222
+ """
223
+ Diffuse the data for a given number of diffusion steps.
224
+ In other words, sample from q(x_t | x_0).
225
+ :param x_start: the initial data batch.
226
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
227
+ :param noise: if specified, the split-out normal noise.
228
+ :return: A noisy version of x_start.
229
+ """
230
+ if noise is None:
231
+ noise = th.randn_like(x_start)
232
+ assert noise.shape == x_start.shape
233
+ return (
234
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
235
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
236
+ )
237
+
238
+ def q_posterior_mean_variance(self, x_start, x_t, t):
239
+ """
240
+ Compute the mean and variance of the diffusion posterior:
241
+ q(x_{t-1} | x_t, x_0)
242
+ """
243
+ assert x_start.shape == x_t.shape
244
+ posterior_mean = (
245
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
246
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
247
+ )
248
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
249
+ posterior_log_variance_clipped = _extract_into_tensor(
250
+ self.posterior_log_variance_clipped, t, x_t.shape
251
+ )
252
+ assert (
253
+ posterior_mean.shape[0]
254
+ == posterior_variance.shape[0]
255
+ == posterior_log_variance_clipped.shape[0]
256
+ == x_start.shape[0]
257
+ )
258
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
259
+
260
+ def q_posterior_mean_variance_dual(self, x_start, x_t, t):
261
+ """
262
+ Compute the posterior mean and variance for each modality:
263
+ q(x_{t-1} | x_t, x_0)
264
+ Inputs:
265
+ x_start: tuple (x_v_start, x_a_start)
266
+ x_t: tuple (x_v_t, x_a_t)
267
+ t: Tensor of shape [B]
268
+ Outputs:
269
+ posterior_mean: (mean_v, mean_a)
270
+ posterior_variance: (var_v, var_a)
271
+ posterior_log_variance_clipped: (logvar_v, logvar_a)
272
+ """
273
+ x_v_start, x_a_start = x_start
274
+ x_v_t, x_a_t = x_t
275
+
276
+ def single_modality_q(x_start_i, x_t_i):
277
+ assert x_start_i.shape == x_t_i.shape
278
+ posterior_mean = (
279
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t_i.shape) * x_start_i
280
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t_i.shape) * x_t_i
281
+ )
282
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t_i.shape)
283
+ posterior_log_variance_clipped = _extract_into_tensor(
284
+ self.posterior_log_variance_clipped, t, x_t_i.shape
285
+ )
286
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
287
+
288
+ mean_v, var_v, logvar_v = single_modality_q(x_v_start, x_v_t)
289
+ mean_a, var_a, logvar_a = single_modality_q(x_a_start, x_a_t)
290
+
291
+ return (mean_v, mean_a), (var_v, var_a), (logvar_v, logvar_a)
292
+
293
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
294
+ """
295
+ Dual-modality version.
296
+ x: (x_v_t, x_a_t)
297
+ model: takes (x_v_t, x_a_t, t, **model_kwargs)
298
+ returns: out_v, out_a: dicts with 'mean', 'variance', 'log_variance', 'pred_xstart'
299
+ """
300
+ if model_kwargs is None:
301
+ model_kwargs = {}
302
+
303
+ x_v, x_a = x
304
+ B, C_v = x_v.shape[:2]
305
+ B, C_a = x_a.shape[:2]
306
+ assert t.shape == (B,)
307
+
308
+ # Call model once to get both outputs
309
+ model_output_v, model_output_a = model(x_v, x_a, t, **model_kwargs)
310
+
311
+ # Helper function for one modality
312
+ def process_modality(x_t, model_output, C):
313
+ if isinstance(model_output, tuple):
314
+ model_output, _ = model_output # drop extra output if any
315
+
316
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
317
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
318
+ model_output, model_var_values = th.split(model_output, C, dim=1)
319
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
320
+ max_log = _extract_into_tensor(np.log(self.betas), t, x_t.shape)
321
+ frac = (model_var_values + 1) / 2
322
+ model_log_variance = frac * max_log + (1 - frac) * min_log
323
+ model_variance = th.exp(model_log_variance)
324
+ else:
325
+ model_variance_, model_log_variance_ = {
326
+ ModelVarType.FIXED_LARGE: (
327
+ np.append(self.posterior_variance[1], self.betas[1:]),
328
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
329
+ ),
330
+ ModelVarType.FIXED_SMALL: (
331
+ self.posterior_variance,
332
+ self.posterior_log_variance_clipped,
333
+ ),
334
+ }[self.model_var_type]
335
+ model_variance = _extract_into_tensor(model_variance_, t, x_t.shape)
336
+ model_log_variance = _extract_into_tensor(model_log_variance_, t, x_t.shape)
337
+
338
+ def process_xstart(x):
339
+ if denoised_fn is not None:
340
+ x = denoised_fn(x)
341
+ if clip_denoised:
342
+ x = x.clamp(-1, 1)
343
+ return x
344
+
345
+ if self.model_mean_type == ModelMeanType.START_X:
346
+ pred_xstart = process_xstart(model_output)
347
+ else:
348
+ pred_xstart = process_xstart(
349
+ self._predict_xstart_from_eps(x_t=x_t, t=t, eps=model_output)
350
+ )
351
+
352
+ model_mean, _, _ = self.q_posterior_mean_variance(
353
+ x_start=pred_xstart, x_t=x_t, t=t
354
+ )
355
+
356
+ return {
357
+ "mean": model_mean,
358
+ "variance": model_variance,
359
+ "log_variance": model_log_variance,
360
+ "pred_xstart": pred_xstart,
361
+ }
362
+
363
+ out_v = process_modality(x_v, model_output_v, C_v)
364
+ out_a = process_modality(x_a, model_output_a, C_a)
365
+
366
+ return out_v, out_a
367
+
368
+
369
+ def _predict_xstart_from_eps(self, x_t, t, eps):
370
+ assert x_t.shape == eps.shape
371
+ return (
372
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
373
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
374
+ )
375
+
376
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
377
+ return (
378
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
379
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
380
+
381
+ def condition_mean(
382
+ self,
383
+ cond_fn, # callable(x_v, x_a, t, **model_kwargs) -> (grad_v, grad_a)
384
+ p_mean_var_v, # dict for video: contains 'mean', 'variance'
385
+ p_mean_var_a, # dict for audio
386
+ x_v, x_a, # x_t for video/audio
387
+ t,
388
+ model_kwargs=None,
389
+ ):
390
+ """
391
+ Compute conditional mean separately for each modality:
392
+ new_mean = mean + variance * ∇ log p(y|x_t)
393
+ """
394
+ if model_kwargs is None:
395
+ model_kwargs = {}
396
+
397
+ # cond_fn must return (grad_v, grad_a)
398
+ grad_v, grad_a = cond_fn(x_v, x_a, t, **model_kwargs)
399
+
400
+ new_mean_v = p_mean_var_v["mean"].float() + p_mean_var_v["variance"] * grad_v.float()
401
+ new_mean_a = p_mean_var_a["mean"].float() + p_mean_var_a["variance"] * grad_a.float()
402
+
403
+ return new_mean_v, new_mean_a
404
+
405
+ def p_sample(
406
+ self,
407
+ model,
408
+ x_v,
409
+ x_a,
410
+ t,
411
+ clip_denoised=True,
412
+ denoised_fn=None,
413
+ cond_fn=None,
414
+ model_kwargs=None,
415
+ ):
416
+ """
417
+ Sample x_{t-1} from the model at the given timestep.
418
+ :param model: the model to sample from.
419
+ :param x: the current tensor at x_{t-1}.
420
+ :param t: the value of t, starting at 0 for the first diffusion step.
421
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
422
+ :param denoised_fn: if not None, a function which applies to the
423
+ x_start prediction before it is used to sample.
424
+ :param cond_fn: if not None, this is a gradient function that acts
425
+ similarly to the model.
426
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
427
+ pass to the model. This can be used for conditioning.
428
+ :return: a dict containing the following keys:
429
+ - 'sample': a random sample from the model.
430
+ - 'pred_xstart': a prediction of x_0.
431
+ """
432
+ # out = self.p_mean_variance(
433
+ # model,
434
+ # x,
435
+ # t,
436
+ # clip_denoised=clip_denoised,
437
+ # denoised_fn=denoised_fn,
438
+ # model_kwargs=model_kwargs,
439
+ # )
440
+ out_v, out_a = self.p_mean_variance(
441
+ model=model,
442
+ x=(x_v, x_a),
443
+ t=t,
444
+ clip_denoised=clip_denoised,
445
+ denoised_fn=denoised_fn,
446
+ model_kwargs=model_kwargs,
447
+ )
448
+ noise_v = th.randn_like(x_v)
449
+ noise_a = th.randn_like(x_a)
450
+
451
+ nonzero_mask_v = (
452
+ (t != 0).float().view(-1, *([1] * (len(x_v.shape) - 1)))
453
+ ) # no noise when t == 0
454
+ nonzero_mask_a = (
455
+ (t != 0).float().view(-1, *([1] * (len(x_a.shape) - 1)))
456
+ )
457
+
458
+ if cond_fn is not None:
459
+
460
+ out_v["mean"], out_a["mean"] = condition_mean(cond_fn, out_v, out_a, x_v, x_a, t, model_kwargs=model_kwargs)
461
+ sample_v = out_v["mean"] + nonzero_mask_v * th.exp(0.5 * out_v["log_variance"]) * noise_v
462
+ sample_a = out_a["mean"] + nonzero_mask_a * th.exp(0.5 * out_a["log_variance"]) * noise_a
463
+ return {"sample_v": sample_v, "sample_a": sample_a, "pred_xstart_v": out_v["pred_xstart"], "pred_xstart_a": out_a["pred_xstart"]}
464
+
465
+ def p_sample_loop(
466
+ self,
467
+ model,
468
+ shape_v,
469
+ shape_a,
470
+ noise_v=None,
471
+ noise_a=None,
472
+ clip_denoised=True,
473
+ denoised_fn=None,
474
+ cond_fn=None,
475
+ model_kwargs=None,
476
+ device=None,
477
+ progress=False,
478
+ ):
479
+ """
480
+ Generate samples from the model.
481
+ :param model: the model module.
482
+ :param shape: the shape of the samples, (N, C, H, W).
483
+ :param noise: if specified, the noise from the encoder to sample.
484
+ Should be of the same shape as `shape`.
485
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
486
+ :param denoised_fn: if not None, a function which applies to the
487
+ x_start prediction before it is used to sample.
488
+ :param cond_fn: if not None, this is a gradient function that acts
489
+ similarly to the model.
490
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
491
+ pass to the model. This can be used for conditioning.
492
+ :param device: if specified, the device to create the samples on.
493
+ If not specified, use a model parameter's device.
494
+ :param progress: if True, show a tqdm progress bar.
495
+ :return: a non-differentiable batch of samples.
496
+ """
497
+ final = None
498
+ for sample in self.p_sample_loop_progressive(
499
+ model,
500
+ shape_v,
501
+ shape_a,
502
+ noise_v=noise_v,
503
+ noise_a=noise_a,
504
+ clip_denoised=clip_denoised,
505
+ denoised_fn=denoised_fn,
506
+ cond_fn=cond_fn,
507
+ model_kwargs=model_kwargs,
508
+ device=device,
509
+ progress=progress,
510
+ ):
511
+ final = sample
512
+ return final["sample_v"], final["sample_a"]
513
+
514
+ def p_sample_loop_progressive(
515
+ self,
516
+ model,
517
+ shape_v,
518
+ shape_a,
519
+ noise_v=None,
520
+ noise_a=None,
521
+ clip_denoised=True,
522
+ denoised_fn=None,
523
+ cond_fn=None,
524
+ model_kwargs=None,
525
+ device=None,
526
+ progress=False,
527
+ ):
528
+ """
529
+ Generate samples from the model and yield intermediate samples from
530
+ each timestep of diffusion.
531
+ Arguments are the same as p_sample_loop().
532
+ Returns a generator over dicts, where each dict is the return value of
533
+ p_sample().
534
+ """
535
+ if device is None:
536
+ device = next(model.parameters()).device
537
+ assert isinstance(shape_v, (tuple, list))
538
+ assert isinstance(shape_a, (tuple, list))
539
+
540
+ if noise_v is not None:
541
+ img = noise_v
542
+ else:
543
+ img = th.randn(*shape_v, device=device)
544
+ if noise_a is not None:
545
+ audio = noise_a
546
+ else:
547
+ audio = th.randn(*shape_a, device=device)
548
+
549
+ indices = list(range(self.num_timesteps))[::-1]
550
+
551
+ if progress:
552
+ # Lazy import so that we don't depend on tqdm.
553
+ from tqdm.auto import tqdm
554
+
555
+ indices = tqdm(indices)
556
+
557
+ for i in indices:
558
+ t = th.tensor([i] * shape_v[0], device=device)
559
+ with th.no_grad():
560
+ #{"sample_v": sample_v, "sample_a": sample_a, "pred_xstart_v": out_v["pred_xstart"], "pred_xstart_a": out_a["pred_xstart"]}
561
+ out = self.p_sample(
562
+ model,
563
+ img,
564
+ audio,
565
+ t,
566
+ clip_denoised=clip_denoised,
567
+ denoised_fn=denoised_fn,
568
+ cond_fn=cond_fn,
569
+ model_kwargs=model_kwargs,
570
+ )
571
+ yield out
572
+ img = out["sample_v"]
573
+ audio = out["sample_a"]
574
+
575
+ def ddim_sample(
576
+ self,
577
+ model,
578
+ x,
579
+ t,
580
+ clip_denoised=True,
581
+ denoised_fn=None,
582
+ cond_fn=None,
583
+ model_kwargs=None,
584
+ eta=0.0,
585
+ ):
586
+ """
587
+ Sample x_{t-1} from the model using DDIM.
588
+ Same usage as p_sample().
589
+ """
590
+ out = self.p_mean_variance(
591
+ model,
592
+ x,
593
+ t,
594
+ clip_denoised=clip_denoised,
595
+ denoised_fn=denoised_fn,
596
+ model_kwargs=model_kwargs,
597
+ )
598
+ if cond_fn is not None:
599
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
600
+
601
+ # Usually our model outputs epsilon, but we re-derive it
602
+ # in case we used x_start or x_prev prediction.
603
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
604
+
605
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
606
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
607
+ sigma = (
608
+ eta
609
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
610
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
611
+ )
612
+ # Equation 12.
613
+ noise = th.randn_like(x)
614
+ mean_pred = (
615
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
616
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
617
+ )
618
+ nonzero_mask = (
619
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
620
+ ) # no noise when t == 0
621
+ sample = mean_pred + nonzero_mask * sigma * noise
622
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
623
+
624
+ def ddim_reverse_sample(
625
+ self,
626
+ model,
627
+ x,
628
+ t,
629
+ clip_denoised=True,
630
+ denoised_fn=None,
631
+ cond_fn=None,
632
+ model_kwargs=None,
633
+ eta=0.0,
634
+ ):
635
+ """
636
+ Sample x_{t+1} from the model using DDIM reverse ODE.
637
+ """
638
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
639
+ out = self.p_mean_variance(
640
+ model,
641
+ x,
642
+ t,
643
+ clip_denoised=clip_denoised,
644
+ denoised_fn=denoised_fn,
645
+ model_kwargs=model_kwargs,
646
+ )
647
+ if cond_fn is not None:
648
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
649
+ # Usually our model outputs epsilon, but we re-derive it
650
+ # in case we used x_start or x_prev prediction.
651
+ eps = (
652
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
653
+ - out["pred_xstart"]
654
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
655
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
656
+
657
+ # Equation 12. reversed
658
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
659
+
660
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
661
+
662
+ def ddim_sample_loop(
663
+ self,
664
+ model,
665
+ shape,
666
+ noise=None,
667
+ clip_denoised=True,
668
+ denoised_fn=None,
669
+ cond_fn=None,
670
+ model_kwargs=None,
671
+ device=None,
672
+ progress=False,
673
+ eta=0.0,
674
+ ):
675
+ """
676
+ Generate samples from the model using DDIM.
677
+ Same usage as p_sample_loop().
678
+ """
679
+ final = None
680
+ for sample in self.ddim_sample_loop_progressive(
681
+ model,
682
+ shape,
683
+ noise=noise,
684
+ clip_denoised=clip_denoised,
685
+ denoised_fn=denoised_fn,
686
+ cond_fn=cond_fn,
687
+ model_kwargs=model_kwargs,
688
+ device=device,
689
+ progress=progress,
690
+ eta=eta,
691
+ ):
692
+ final = sample
693
+ return final["sample"]
694
+
695
+ def ddim_sample_loop_progressive(
696
+ self,
697
+ model,
698
+ shape,
699
+ noise=None,
700
+ clip_denoised=True,
701
+ denoised_fn=None,
702
+ cond_fn=None,
703
+ model_kwargs=None,
704
+ device=None,
705
+ progress=False,
706
+ eta=0.0,
707
+ ):
708
+ """
709
+ Use DDIM to sample from the model and yield intermediate samples from
710
+ each timestep of DDIM.
711
+ Same usage as p_sample_loop_progressive().
712
+ """
713
+ if device is None:
714
+ device = next(model.parameters()).device
715
+ assert isinstance(shape, (tuple, list))
716
+ if noise is not None:
717
+ img = noise
718
+ else:
719
+ img = th.randn(*shape, device=device)
720
+ indices = list(range(self.num_timesteps))[::-1]
721
+
722
+ if progress:
723
+ # Lazy import so that we don't depend on tqdm.
724
+ from tqdm.auto import tqdm
725
+
726
+ indices = tqdm(indices)
727
+
728
+ for i in indices:
729
+ t = th.tensor([i] * shape[0], device=device)
730
+ with th.no_grad():
731
+ out = self.ddim_sample(
732
+ model,
733
+ img,
734
+ t,
735
+ clip_denoised=clip_denoised,
736
+ denoised_fn=denoised_fn,
737
+ cond_fn=cond_fn,
738
+ model_kwargs=model_kwargs,
739
+ eta=eta,
740
+ )
741
+ yield out
742
+ img = out["sample"]
743
+
744
+ def _vb_terms_bpd(
745
+ self, model, x_v_start, x_a_start, x_v_t, x_a_t, t, clip_denoised=True, model_kwargs=None
746
+ ):
747
+ """
748
+ Dual-modality VB loss.
749
+ """
750
+
751
+ # --- True posterior
752
+ (true_mean_v, true_mean_a), _, (logvar_v, logvar_a) = self.q_posterior_mean_variance_dual(
753
+ x_start=(x_v_start, x_a_start),
754
+ x_t=(x_v_t, x_a_t),
755
+ t=t,
756
+ )
757
+
758
+ # --- Model prediction
759
+ out_v, out_a = self.p_mean_variance(
760
+ model=model,
761
+ x=(x_v_t, x_a_t),
762
+ t=t,
763
+ clip_denoised=clip_denoised,
764
+ model_kwargs=model_kwargs,
765
+ )
766
+
767
+ # --- KL loss
768
+ kl_v = normal_kl(true_mean_v, logvar_v, out_v["mean"], out_v["log_variance"])
769
+ kl_a = normal_kl(true_mean_a, logvar_a, out_a["mean"], out_a["log_variance"])
770
+ kl_v = mean_flat(kl_v) / np.log(2.0)
771
+ kl_a = mean_flat(kl_a) / np.log(2.0)
772
+
773
+ # --- NLL loss (only at t=0)
774
+ decoder_nll_v = -discretized_gaussian_log_likelihood(
775
+ x_v_start, means=out_v["mean"], log_scales=0.5 * out_v["log_variance"]
776
+ )
777
+ decoder_nll_v = mean_flat(decoder_nll_v) / np.log(2.0)
778
+
779
+ decoder_nll_a = -discretized_gaussian_log_likelihood(
780
+ x_a_start, means=out_a["mean"], log_scales=0.5 * out_a["log_variance"]
781
+ )
782
+ decoder_nll_a = mean_flat(decoder_nll_a) / np.log(2.0)
783
+
784
+ # --- Final VB loss
785
+ output_v = th.where((t == 0), decoder_nll_v, kl_v)
786
+ output_a = th.where((t == 0), decoder_nll_a, kl_a)
787
+
788
+ return {
789
+ "output_v": output_v,
790
+ "output_a": output_a,
791
+ "pred_xstart": (out_v["pred_xstart"], out_a["pred_xstart"]),
792
+ }
793
+
794
+ def training_losses(self, model, x_v_start, x_a_start, t, model_kwargs=None, noise_v=None, noise_a=None):
795
+ """
796
+ Compute training losses for a single timestep.
797
+ :param model: the model to evaluate loss on.
798
+ :param x_start: the [N x C x ...] tensor of inputs.
799
+ :param t: a batch of timestep indices.
800
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
801
+ pass to the model. This can be used for conditioning.
802
+ :param noise: if specified, the specific Gaussian noise to try to remove.
803
+ :return: a dict with the key "loss" containing a tensor of shape [N].
804
+ Some mean or variance settings may also have other keys.
805
+ """
806
+ if model_kwargs is None:
807
+ model_kwargs = {}
808
+ if noise_v is None:
809
+ noise_v = th.randn_like(x_v_start)
810
+ x_v_t = self.q_sample(x_v_start, t, noise=noise_v)
811
+ if noise_a is None:
812
+ noise_a = th.randn_like(x_a_start)
813
+ x_a_t = self.q_sample(x_a_start, t, noise=noise_a)
814
+
815
+ terms = {}
816
+
817
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
818
+ vb_terms = self._vb_terms_bpd(
819
+ model=model,
820
+ x_v_start=x_v_start,
821
+ x_a_start=x_a_start,
822
+ x_v_t=x_v_t,
823
+ x_a_t=x_a_t,
824
+ t=t,
825
+ clip_denoised=False,
826
+ model_kwargs=model_kwargs,
827
+ )
828
+ terms["vb_v"] = vb_terms["output_v"]
829
+ terms["vb_a"] = vb_terms["output_a"]
830
+ terms["loss"] = vb_terms["output_v"] + vb_terms["output_a"]
831
+ if self.loss_type == LossType.RESCALED_KL:
832
+ terms["loss"] *= self.num_timesteps
833
+
834
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
835
+ model_output_v, model_output_a = model(x_v_t, x_a_t, t, **model_kwargs)
836
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
837
+ B, C_v = x_v_t.shape[:2]
838
+ B, C_a = x_a_t.shape[:2]
839
+
840
+ model_output_v, model_var_v = th.split(model_output_v, C_v, dim=1)
841
+ model_output_a, model_var_a = th.split(model_output_a, C_a, dim=1)
842
+
843
+ frozen_out_v = th.cat([model_output_v.detach(), model_var_v], dim=1)
844
+ frozen_out_a = th.cat([model_output_a.detach(), model_var_a], dim=1)
845
+
846
+ frozen_model = lambda *args, **kwargs: (frozen_out_v, frozen_out_a)
847
+
848
+ vb_output = self._vb_terms_bpd(
849
+ model=frozen_model,
850
+ x_v_start=x_v_start,
851
+ x_a_start=x_a_start,
852
+ x_v_t=x_v_t,
853
+ x_a_t=x_a_t,
854
+ t=t,
855
+ clip_denoised=False,
856
+ )
857
+
858
+ terms["vb_v"] = vb_output["output_v"]
859
+ terms["vb_a"] = vb_output["output_a"]
860
+
861
+ # === MSE Loss ===
862
+ def process_mse(modality, x_start, x_t, model_output, noise):
863
+ target = {
864
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance_dual(
865
+ x_start=(x_v_start, x_a_start),
866
+ x_t=(x_v_t, x_a_t),
867
+ t=t,
868
+ )[0][0 if modality == "v" else 1],
869
+ ModelMeanType.START_X: x_start,
870
+ ModelMeanType.EPSILON: noise,
871
+ }[self.model_mean_type]
872
+
873
+ assert model_output.shape == target.shape == x_start.shape
874
+ terms[f"mse_{modality}"] = mean_flat((target - model_output) ** 2)
875
+
876
+ process_mse("v", x_v_start, x_v_t, model_output_v, noise_v)
877
+ process_mse("a", x_a_start, x_a_t, model_output_a, noise_a)
878
+
879
+ if "vb_v" in terms and "vb_a" in terms:
880
+ terms["vb"] = terms["vb_v"] + terms["vb_a"]
881
+ if self.loss_type == LossType.RESCALED_MSE:
882
+ terms["vb"] *= self.num_timesteps / 1000.0
883
+
884
+ terms["loss"] = terms["mse_v"] + terms["mse_a"]
885
+ if "vb" in terms:
886
+ terms["loss"] += terms["vb"]
887
+
888
+
889
+ return terms
890
+
891
+ def _prior_bpd(self, x_start):
892
+ """
893
+ Get the prior KL term for the variational lower-bound, measured in
894
+ bits-per-dim.
895
+ This term can't be optimized, as it only depends on the encoder.
896
+ :param x_start: the [N x C x ...] tensor of inputs.
897
+ :return: a batch of [N] KL values (in bits), one per batch element.
898
+ """
899
+ batch_size = x_start.shape[0]
900
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
901
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
902
+ kl_prior = normal_kl(
903
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
904
+ )
905
+ return mean_flat(kl_prior) / np.log(2.0)
906
+
907
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
908
+ """
909
+ Compute the entire variational lower-bound, measured in bits-per-dim,
910
+ as well as other related quantities.
911
+ :param model: the model to evaluate loss on.
912
+ :param x_start: the [N x C x ...] tensor of inputs.
913
+ :param clip_denoised: if True, clip denoised samples.
914
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
915
+ pass to the model. This can be used for conditioning.
916
+ :return: a dict containing the following keys:
917
+ - total_bpd: the total variational lower-bound, per batch element.
918
+ - prior_bpd: the prior term in the lower-bound.
919
+ - vb: an [N x T] tensor of terms in the lower-bound.
920
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
921
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
922
+ """
923
+ device = x_start.device
924
+ batch_size = x_start.shape[0]
925
+
926
+ vb = []
927
+ xstart_mse = []
928
+ mse = []
929
+ for t in list(range(self.num_timesteps))[::-1]:
930
+ t_batch = th.tensor([t] * batch_size, device=device)
931
+ noise = th.randn_like(x_start)
932
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
933
+ # Calculate VLB term at the current timestep
934
+ with th.no_grad():
935
+ out = self._vb_terms_bpd(
936
+ model,
937
+ x_start=x_start,
938
+ x_t=x_t,
939
+ t=t_batch,
940
+ clip_denoised=clip_denoised,
941
+ model_kwargs=model_kwargs,
942
+ )
943
+ vb.append(out["output"])
944
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
945
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
946
+ mse.append(mean_flat((eps - noise) ** 2))
947
+
948
+ vb = th.stack(vb, dim=1)
949
+ xstart_mse = th.stack(xstart_mse, dim=1)
950
+ mse = th.stack(mse, dim=1)
951
+
952
+ prior_bpd = self._prior_bpd(x_start)
953
+ total_bpd = vb.sum(dim=1) + prior_bpd
954
+ return {
955
+ "total_bpd": total_bpd,
956
+ "prior_bpd": prior_bpd,
957
+ "vb": vb,
958
+ "xstart_mse": xstart_mse,
959
+ "mse": mse,
960
+ }
961
+
962
+
963
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
964
+ """
965
+ Extract values from a 1-D numpy array for a batch of indices.
966
+ :param arr: the 1-D numpy array.
967
+ :param timesteps: a tensor of indices into the array to extract.
968
+ :param broadcast_shape: a larger shape of K dimensions with the batch
969
+ dimension equal to the length of timesteps.
970
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
971
+ """
972
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
973
+ while len(res.shape) < len(broadcast_shape):
974
+ res = res[..., None]
975
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
diffusion/respace.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch as th
3
+
4
+ from .gaussian_diffusion import GaussianDiffusion
5
+
6
+
7
+ def space_timesteps(num_timesteps, section_counts):
8
+ """
9
+ Create a list of timesteps to use from an original diffusion process,
10
+ given the number of timesteps we want to take from equally-sized portions
11
+ of the original process.
12
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
13
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
14
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
15
+ If the stride is a string starting with "ddim", then the fixed striding
16
+ from the DDIM paper is used, and only one section is allowed.
17
+ :param num_timesteps: the number of diffusion steps in the original
18
+ process to divide up.
19
+ :param section_counts: either a list of numbers, or a string containing
20
+ comma-separated numbers, indicating the step count
21
+ per section. As a special case, use "ddimN" where N
22
+ is a number of steps to use the striding from the
23
+ DDIM paper.
24
+ :return: a set of diffusion steps from the original process to use.
25
+ """
26
+ if isinstance(section_counts, str):
27
+ if section_counts.startswith("ddim"):
28
+ desired_count = int(section_counts[len("ddim") :])
29
+ for i in range(1, num_timesteps):
30
+ if len(range(0, num_timesteps, i)) == desired_count:
31
+ return set(range(0, num_timesteps, i))
32
+ raise ValueError(
33
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
34
+ )
35
+ section_counts = [int(x) for x in section_counts.split(",")]
36
+ size_per = num_timesteps // len(section_counts)
37
+ extra = num_timesteps % len(section_counts)
38
+ start_idx = 0
39
+ all_steps = []
40
+ for i, section_count in enumerate(section_counts):
41
+ size = size_per + (1 if i < extra else 0)
42
+ if size < section_count:
43
+ raise ValueError(
44
+ f"cannot divide section of {size} steps into {section_count}"
45
+ )
46
+ if section_count <= 1:
47
+ frac_stride = 1
48
+ else:
49
+ frac_stride = (size - 1) / (section_count - 1)
50
+ cur_idx = 0.0
51
+ taken_steps = []
52
+ for _ in range(section_count):
53
+ taken_steps.append(start_idx + round(cur_idx))
54
+ cur_idx += frac_stride
55
+ all_steps += taken_steps
56
+ start_idx += size
57
+ return set(all_steps)
58
+
59
+
60
+ class SpacedDiffusion(GaussianDiffusion):
61
+ """
62
+ A diffusion process which can skip steps in a base diffusion process.
63
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
64
+ original diffusion process to retain.
65
+ :param kwargs: the kwargs to create the base diffusion process.
66
+ """
67
+
68
+ def __init__(self, use_timesteps, dual, **kwargs):
69
+ self.use_timesteps = set(use_timesteps)
70
+ self.timestep_map = []
71
+ self.original_num_steps = len(kwargs["betas"])
72
+ self.dual = dual
73
+
74
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
75
+ last_alpha_cumprod = 1.0
76
+ new_betas = []
77
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
78
+ if i in self.use_timesteps:
79
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
80
+ last_alpha_cumprod = alpha_cumprod
81
+ self.timestep_map.append(i)
82
+ kwargs["betas"] = np.array(new_betas)
83
+ super().__init__(**kwargs)
84
+
85
+ def p_mean_variance(
86
+ self, model, *args, **kwargs
87
+ ): # pylint: disable=signature-differs
88
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
89
+
90
+ def training_losses(
91
+ self, model, *args, **kwargs
92
+ ): # pylint: disable=signature-differs
93
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
94
+
95
+ def condition_mean(self, cond_fn, *args, **kwargs):
96
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
97
+
98
+ def condition_score(self, cond_fn, *args, **kwargs):
99
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
100
+
101
+ def _wrap_model(self, model):
102
+ if isinstance(model, _WrappedModel):
103
+ return model
104
+ return _WrappedModel(
105
+ model, self.timestep_map, self.original_num_steps, self.dual
106
+ )
107
+
108
+ def _scale_timesteps(self, t):
109
+ # Scaling is done by the wrapped model.
110
+ return t
111
+
112
+ class _WrappedModel:
113
+ def __init__(self, model, timestep_map, original_num_steps):
114
+ self.model = model
115
+ self.timestep_map = timestep_map
116
+ # self.rescale_timesteps = rescale_timesteps
117
+ self.original_num_steps = original_num_steps
118
+
119
+ def __call__(self, x, ts, **kwargs):
120
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
121
+ new_ts = map_tensor[ts]
122
+ # if self.rescale_timesteps:
123
+ # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
124
+ return self.model(x, new_ts, **kwargs)
125
+
diffusion/respace_dual.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # Modified from OpenAI's diffusion repos
8
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
9
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
10
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
11
+
12
+ import numpy as np
13
+ import torch as th
14
+
15
+ from .gaussian_diffusion_dual import GaussianDiffusion
16
+
17
+
18
+ def space_timesteps(num_timesteps, section_counts):
19
+ """
20
+ Create a list of timesteps to use from an original diffusion process,
21
+ given the number of timesteps we want to take from equally-sized portions
22
+ of the original process.
23
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
24
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
25
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
26
+ If the stride is a string starting with "ddim", then the fixed striding
27
+ from the DDIM paper is used, and only one section is allowed.
28
+ :param num_timesteps: the number of diffusion steps in the original
29
+ process to divide up.
30
+ :param section_counts: either a list of numbers, or a string containing
31
+ comma-separated numbers, indicating the step count
32
+ per section. As a special case, use "ddimN" where N
33
+ is a number of steps to use the striding from the
34
+ DDIM paper.
35
+ :return: a set of diffusion steps from the original process to use.
36
+ """
37
+ if isinstance(section_counts, str):
38
+ if section_counts.startswith("ddim"):
39
+ desired_count = int(section_counts[len("ddim") :])
40
+ for i in range(1, num_timesteps):
41
+ if len(range(0, num_timesteps, i)) == desired_count:
42
+ return set(range(0, num_timesteps, i))
43
+ raise ValueError(
44
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
45
+ )
46
+ section_counts = [int(x) for x in section_counts.split(",")]
47
+ size_per = num_timesteps // len(section_counts)
48
+ extra = num_timesteps % len(section_counts)
49
+ start_idx = 0
50
+ all_steps = []
51
+ for i, section_count in enumerate(section_counts):
52
+ size = size_per + (1 if i < extra else 0)
53
+ if size < section_count:
54
+ raise ValueError(
55
+ f"cannot divide section of {size} steps into {section_count}"
56
+ )
57
+ if section_count <= 1:
58
+ frac_stride = 1
59
+ else:
60
+ frac_stride = (size - 1) / (section_count - 1)
61
+ cur_idx = 0.0
62
+ taken_steps = []
63
+ for _ in range(section_count):
64
+ taken_steps.append(start_idx + round(cur_idx))
65
+ cur_idx += frac_stride
66
+ all_steps += taken_steps
67
+ start_idx += size
68
+ return set(all_steps)
69
+
70
+
71
+ class SpacedDiffusion(GaussianDiffusion):
72
+ """
73
+ A diffusion process which can skip steps in a base diffusion process.
74
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
75
+ original diffusion process to retain.
76
+ :param kwargs: the kwargs to create the base diffusion process.
77
+ """
78
+
79
+ def __init__(self, use_timesteps, **kwargs):
80
+ self.use_timesteps = set(use_timesteps)
81
+ self.timestep_map = []
82
+ self.original_num_steps = len(kwargs["betas"])
83
+
84
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
85
+ last_alpha_cumprod = 1.0
86
+ new_betas = []
87
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
88
+ if i in self.use_timesteps:
89
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
90
+ last_alpha_cumprod = alpha_cumprod
91
+ self.timestep_map.append(i)
92
+ kwargs["betas"] = np.array(new_betas)
93
+ super().__init__(**kwargs)
94
+
95
+ def p_mean_variance(
96
+ self, model, *args, **kwargs
97
+ ): # pylint: disable=signature-differs
98
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
99
+
100
+ def training_losses(
101
+ self, model, *args, **kwargs
102
+ ): # pylint: disable=signature-differs
103
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
104
+
105
+ def condition_mean(self, cond_fn, *args, **kwargs):
106
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
107
+
108
+ def condition_score(self, cond_fn, *args, **kwargs):
109
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
110
+
111
+ def _wrap_model(self, model):
112
+ if isinstance(model, _WrappedModel):
113
+ return model
114
+ return _WrappedModel(
115
+ model, self.timestep_map, self.original_num_steps
116
+ )
117
+
118
+ def _scale_timesteps(self, t):
119
+ # Scaling is done by the wrapped model.
120
+ return t
121
+
122
+
123
+ class _WrappedModel:
124
+ def __init__(self, model, timestep_map, original_num_steps):
125
+ self.model = model
126
+ self.timestep_map = timestep_map
127
+ # self.rescale_timesteps = rescale_timesteps
128
+ self.original_num_steps = original_num_steps
129
+
130
+ def __call__(self, x_v, x_a, ts, **kwargs):
131
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
132
+ new_ts = map_tensor[ts]
133
+ # if self.rescale_timesteps:
134
+ # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
135
+ return self.model(x_v, x_a, new_ts, **kwargs)
diffusion/timestep_sampler.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+ import numpy as np
4
+ import torch as th
5
+ import torch.distributed as dist
6
+
7
+
8
+ def create_named_schedule_sampler(name, diffusion):
9
+ """
10
+ Create a ScheduleSampler from a library of pre-defined samplers.
11
+ :param name: the name of the sampler.
12
+ :param diffusion: the diffusion object to sample for.
13
+ """
14
+ if name == "uniform":
15
+ return UniformSampler(diffusion)
16
+ elif name == "loss-second-moment":
17
+ return LossSecondMomentResampler(diffusion)
18
+ else:
19
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
20
+
21
+
22
+ class ScheduleSampler(ABC):
23
+ """
24
+ A distribution over timesteps in the diffusion process, intended to reduce
25
+ variance of the objective.
26
+ By default, samplers perform unbiased importance sampling, in which the
27
+ objective's mean is unchanged.
28
+ However, subclasses may override sample() to change how the resampled
29
+ terms are reweighted, allowing for actual changes in the objective.
30
+ """
31
+
32
+ @abstractmethod
33
+ def weights(self):
34
+ """
35
+ Get a numpy array of weights, one per diffusion step.
36
+ The weights needn't be normalized, but must be positive.
37
+ """
38
+
39
+ def sample(self, batch_size, device):
40
+ """
41
+ Importance-sample timesteps for a batch.
42
+ :param batch_size: the number of timesteps.
43
+ :param device: the torch device to save to.
44
+ :return: a tuple (timesteps, weights):
45
+ - timesteps: a tensor of timestep indices.
46
+ - weights: a tensor of weights to scale the resulting losses.
47
+ """
48
+ w = self.weights()
49
+ p = w / np.sum(w)
50
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
51
+ indices = th.from_numpy(indices_np).long().to(device)
52
+ weights_np = 1 / (len(p) * p[indices_np])
53
+ weights = th.from_numpy(weights_np).float().to(device)
54
+ return indices, weights
55
+
56
+
57
+ class UniformSampler(ScheduleSampler):
58
+ def __init__(self, diffusion):
59
+ self.diffusion = diffusion
60
+ self._weights = np.ones([diffusion.num_timesteps])
61
+
62
+ def weights(self):
63
+ return self._weights
64
+
65
+
66
+ class LossAwareSampler(ScheduleSampler):
67
+ def update_with_local_losses(self, local_ts, local_losses):
68
+ """
69
+ Update the reweighting using losses from a model.
70
+ Call this method from each rank with a batch of timesteps and the
71
+ corresponding losses for each of those timesteps.
72
+ This method will perform synchronization to make sure all of the ranks
73
+ maintain the exact same reweighting.
74
+ :param local_ts: an integer Tensor of timesteps.
75
+ :param local_losses: a 1D Tensor of losses.
76
+ """
77
+ batch_sizes = [
78
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
79
+ for _ in range(dist.get_world_size())
80
+ ]
81
+ dist.all_gather(
82
+ batch_sizes,
83
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
84
+ )
85
+
86
+ # Pad all_gather batches to be the maximum batch size.
87
+ batch_sizes = [x.item() for x in batch_sizes]
88
+ max_bs = max(batch_sizes)
89
+
90
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
91
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
92
+ dist.all_gather(timestep_batches, local_ts)
93
+ dist.all_gather(loss_batches, local_losses)
94
+ timesteps = [
95
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
96
+ ]
97
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
98
+ self.update_with_all_losses(timesteps, losses)
99
+
100
+ @abstractmethod
101
+ def update_with_all_losses(self, ts, losses):
102
+ """
103
+ Update the reweighting using losses from a model.
104
+ Sub-classes should override this method to update the reweighting
105
+ using losses from the model.
106
+ This method directly updates the reweighting without synchronizing
107
+ between workers. It is called by update_with_local_losses from all
108
+ ranks with identical arguments. Thus, it should have deterministic
109
+ behavior to maintain state across workers.
110
+ :param ts: a list of int timesteps.
111
+ :param losses: a list of float losses, one per timestep.
112
+ """
113
+
114
+
115
+ class LossSecondMomentResampler(LossAwareSampler):
116
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
117
+ self.diffusion = diffusion
118
+ self.history_per_term = history_per_term
119
+ self.uniform_prob = uniform_prob
120
+ self._loss_history = np.zeros(
121
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
122
+ )
123
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)
124
+
125
+ def weights(self):
126
+ if not self._warmed_up():
127
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
128
+ weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
129
+ weights /= np.sum(weights)
130
+ weights *= 1 - self.uniform_prob
131
+ weights += self.uniform_prob / len(weights)
132
+ return weights
133
+
134
+ def update_with_all_losses(self, ts, losses):
135
+ for t, loss in zip(ts, losses):
136
+ if self._loss_counts[t] == self.history_per_term:
137
+ # Shift out the oldest loss term.
138
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
139
+ self._loss_history[t, -1] = loss
140
+ else:
141
+ self._loss_history[t, self._loss_counts[t]] = loss
142
+ self._loss_counts[t] += 1
143
+
144
+ def _warmed_up(self):
145
+ return (self._loss_counts == self.history_per_term).all()
distributed.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+ from torcheval.metrics import FrechetInceptionDistance
4
+
5
+ from collections import defaultdict, deque
6
+ import os
7
+ import datetime
8
+ import builtins
9
+ from logging import getLogger
10
+ import pickle
11
+ import time
12
+
13
+ logger = getLogger()
14
+
15
+ def is_dist_avail_and_initialized():
16
+ if not dist.is_available():
17
+ return False
18
+ if not dist.is_initialized():
19
+ return False
20
+ return True
21
+
22
+ def get_world_size():
23
+ if not is_dist_avail_and_initialized():
24
+ return 1
25
+ return dist.get_world_size()
26
+
27
+ def get_rank():
28
+ if not is_dist_avail_and_initialized():
29
+ return 0
30
+ return dist.get_rank()
31
+
32
+ def is_main_process():
33
+ return get_rank() == 0
34
+
35
+ def setup_for_distributed(is_master):
36
+ """
37
+ This function disables printing when not in master process
38
+ """
39
+ builtin_print = builtins.print
40
+
41
+ def print(*args, **kwargs):
42
+ force = kwargs.pop('force', False)
43
+ force = force or (get_world_size() > 8)
44
+ if is_master or force:
45
+ now = datetime.datetime.now().time()
46
+ builtin_print('[{}] '.format(now), end='') # print with time stamp
47
+ builtin_print(*args, **kwargs)
48
+
49
+ builtins.print = print
50
+
51
+ def init_distributed(port=37124, rank_and_world_size=(None, None)):
52
+ rank, world_size = rank_and_world_size
53
+ dist_url='env://'
54
+ os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(port))
55
+ print("Using port", os.environ['MASTER_PORT'])
56
+
57
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
58
+ try:
59
+ rank = int(os.environ["RANK"])
60
+ world_size = int(os.environ["WORLD_SIZE"])
61
+ gpu = int(os.environ["LOCAL_RANK"])
62
+ except Exception:
63
+ logger.info('torchrun env vars not sets')
64
+
65
+ elif "SLURM_PROCID" in os.environ:
66
+ try:
67
+ world_size = int(os.environ['SLURM_NTASKS'])
68
+ rank = int(os.environ['SLURM_PROCID'])
69
+ gpu = rank % torch.cuda.device_count()
70
+ if 'HOSTNAME' in os.environ:
71
+ os.environ['MASTER_ADDR'] = os.environ['HOSTNAME']
72
+ else:
73
+ os.environ['MASTER_ADDR'] = '127.0.0.1'
74
+ except Exception:
75
+ logger.info('SLURM vars not set')
76
+
77
+ else:
78
+ rank = 0
79
+ world_size = 1
80
+ gpu = 0
81
+ os.environ['MASTER_ADDR'] = '127.0.0.1'
82
+
83
+ torch.cuda.set_device(gpu)
84
+
85
+ torch.distributed.init_process_group(
86
+ backend='nccl',
87
+ world_size=world_size,
88
+ rank=rank,
89
+ init_method=dist_url
90
+ )
91
+
92
+ # setup_for_distributed(rank == 0)
93
+ return world_size, rank, gpu, True
94
+
95
+
96
+ class SmoothedValue(object):
97
+ """Track a series of values and provide access to smoothed values over a
98
+ window or the global series average.
99
+ """
100
+
101
+ def __init__(self, window_size=20, fmt=None):
102
+ if fmt is None:
103
+ fmt = "{median:.4f} ({global_avg:.4f})"
104
+ self.deque = deque(maxlen=window_size)
105
+ self.total = 0.0
106
+ self.count = 0
107
+ self.fmt = fmt
108
+
109
+ def update(self, value, n=1):
110
+ self.deque.append(value)
111
+ self.count += n
112
+ self.total += value * n
113
+
114
+ def synchronize_between_processes(self):
115
+ """
116
+ Warning: does not synchronize the deque!
117
+ """
118
+ if not is_dist_avail_and_initialized():
119
+ return
120
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
121
+ dist.barrier()
122
+ dist.all_reduce(t)
123
+ t = t.tolist()
124
+ self.count = int(t[0])
125
+ self.total = t[1]
126
+
127
+ @property
128
+ def median(self):
129
+ d = torch.tensor(list(self.deque))
130
+ return d.median().item()
131
+
132
+ @property
133
+ def avg(self):
134
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
135
+ return d.mean().item()
136
+
137
+ @property
138
+ def global_avg(self):
139
+ return self.total / self.count
140
+
141
+ @property
142
+ def max(self):
143
+ return max(self.deque)
144
+
145
+ @property
146
+ def value(self):
147
+ return self.deque[-1]
148
+
149
+ def __str__(self):
150
+ return self.fmt.format(
151
+ median=self.median,
152
+ avg=self.avg,
153
+ global_avg=self.global_avg,
154
+ max=self.max,
155
+ value=self.value)
156
+
157
+ class MetricLogger(object):
158
+ def __init__(self, delimiter="\t"):
159
+ self.meters = defaultdict(SmoothedValue)
160
+ self.delimiter = delimiter
161
+
162
+ def update(self, **kwargs):
163
+ for k, v in kwargs.items():
164
+ if v is None:
165
+ continue
166
+ if isinstance(v, torch.Tensor):
167
+ v = v.item()
168
+ assert isinstance(v, (float, int))
169
+ self.meters[k].update(v)
170
+
171
+ def __getattr__(self, attr):
172
+ if attr in self.meters:
173
+ return self.meters[attr]
174
+ if attr in self.__dict__:
175
+ return self.__dict__[attr]
176
+ raise AttributeError("'{}' object has no attribute '{}'".format(
177
+ type(self).__name__, attr))
178
+
179
+ def __str__(self):
180
+ loss_str = []
181
+ for name, meter in self.meters.items():
182
+ loss_str.append(
183
+ "{}: {}".format(name, str(meter))
184
+ )
185
+ return self.delimiter.join(loss_str)
186
+
187
+ def synchronize_between_processes(self):
188
+ for meter in self.meters.values():
189
+ meter.synchronize_between_processes()
190
+
191
+ def add_meter(self, name, meter):
192
+ self.meters[name] = meter
193
+
194
+ def log_every(self, iterable, print_freq, header=None):
195
+ i = 0
196
+ if not header:
197
+ header = ''
198
+ start_time = time.time()
199
+ end = time.time()
200
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
201
+ data_time = SmoothedValue(fmt='{avg:.4f}')
202
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
203
+ log_msg = [
204
+ header,
205
+ '[{0' + space_fmt + '}/{1}]',
206
+ 'eta: {eta}',
207
+ '{meters}',
208
+ 'time: {time}',
209
+ 'data: {data}'
210
+ ]
211
+ if torch.cuda.is_available():
212
+ log_msg.append('max mem: {memory:.0f}')
213
+ log_msg = self.delimiter.join(log_msg)
214
+ MB = 1024.0 * 1024.0
215
+ for obj in iterable:
216
+ data_time.update(time.time() - end)
217
+ yield obj
218
+ iter_time.update(time.time() - end)
219
+ if i % print_freq == 0 or i == len(iterable) - 1:
220
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
221
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
222
+ if torch.cuda.is_available():
223
+ print(log_msg.format(
224
+ i, len(iterable), eta=eta_string,
225
+ meters=str(self),
226
+ time=str(iter_time), data=str(data_time),
227
+ memory=torch.cuda.max_memory_allocated() / MB))
228
+ else:
229
+ print(log_msg.format(
230
+ i, len(iterable), eta=eta_string,
231
+ meters=str(self),
232
+ time=str(iter_time), data=str(data_time)))
233
+ i += 1
234
+ end = time.time()
235
+ total_time = time.time() - start_time
236
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
237
+ print('{} Total time: {} ({:.4f} s / it)'.format(
238
+ header, total_time_str, total_time / len(iterable)))
239
+ self.update(total_time=total_time)
240
+
241
+ def sync_fid_loss_fns(fid_loss_fn, device="cuda"):
242
+ """
243
+ Synchronizes FID loss function metrics across all processes.
244
+
245
+ Args:
246
+ fid_loss_fn (dict): Local FID loss function metrics on each process.
247
+ device (str): Device to move the merged FID metrics to.
248
+
249
+ Returns:
250
+ final_fid_loss_fn (dict): Merged FID loss function metrics on all processes.
251
+ """
252
+ if not is_dist_avail_and_initialized():
253
+ return fid_loss_fn
254
+
255
+ serialized_fid_loss_fn = pickle.dumps(fid_loss_fn)
256
+ gathered_fid_loss_fn = [None] * dist.get_world_size()
257
+
258
+ dist.barrier()
259
+
260
+ dist.all_gather_object(gathered_fid_loss_fn, serialized_fid_loss_fn)
261
+
262
+ final_fid_loss_fn = {
263
+ 1: FrechetInceptionDistance(feature_dim=2048).to(device),
264
+ 2: FrechetInceptionDistance(feature_dim=2048).to(device),
265
+ 4: FrechetInceptionDistance(feature_dim=2048).to(device),
266
+ 8: FrechetInceptionDistance(feature_dim=2048).to(device),
267
+ 16: FrechetInceptionDistance(feature_dim=2048).to(device),
268
+ }
269
+
270
+ for serialized_fid_loss_fn in gathered_fid_loss_fn:
271
+ curr_fid_loss_fn = pickle.loads(serialized_fid_loss_fn)
272
+ for sec in [1, 2, 4, 8, 16]:
273
+ sec_fid_loss_fn = curr_fid_loss_fn[sec]
274
+ final_fid_loss_fn[sec].merge_state([sec_fid_loss_fn])
275
+
276
+ return final_fid_loss_fn
277
+
eval_audio.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # eval_audio.py
2
+ from typing import Optional
3
+ import os
4
+ import re
5
+ import argparse
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torchaudio
10
+ import librosa
11
+ import matplotlib.pyplot as plt
12
+
13
+ _EPS = 1e-12
14
+
15
+ def build_mel_transform(
16
+ sample_rate,
17
+ n_fft=1024,
18
+ win_length=1024,
19
+ hop_length=256,
20
+ n_mels=80,
21
+ power=1.0,
22
+ f_min=0.0,
23
+ f_max=None,
24
+ mel_scale="htk",
25
+ norm=None,
26
+ device=None,
27
+ ):
28
+ mel_tf = torchaudio.transforms.MelSpectrogram(
29
+ sample_rate=sample_rate,
30
+ n_fft=n_fft,
31
+ win_length=win_length,
32
+ hop_length=hop_length,
33
+ f_min=f_min,
34
+ f_max=f_max,
35
+ n_mels=n_mels,
36
+ power=power,
37
+ center=True,
38
+ norm=norm,
39
+ mel_scale=mel_scale,
40
+ )
41
+ if device is not None:
42
+ mel_tf = mel_tf.to(device)
43
+ return mel_tf
44
+
45
+
46
+ def _ensure_stereo_torch(x):
47
+ if x.dim() == 1:
48
+ x = x.unsqueeze(0)
49
+ if x.size(0) == 1:
50
+ x = x.repeat(2, 1)
51
+ elif x.size(0) > 2:
52
+ x = x[:2]
53
+ return x
54
+
55
+
56
+ @torch.no_grad()
57
+ def mel_cosine_stereo(
58
+ ref, hat, sample_rate,
59
+ n_fft=1024,
60
+ win_length=1024,
61
+ hop_length=256,
62
+ n_mels=80,
63
+ power=1.0,
64
+ mel_tf=None,
65
+ ):
66
+ ref = _ensure_stereo_torch(ref)
67
+ hat = _ensure_stereo_torch(hat)
68
+
69
+ device = ref.device
70
+ if mel_tf is None:
71
+ mel_tf = build_mel_transform(
72
+ sample_rate=sample_rate,
73
+ n_fft=n_fft, win_length=win_length, hop_length=hop_length,
74
+ n_mels=n_mels, power=power, device=device
75
+ )
76
+ else:
77
+ mel_tf = mel_tf.to(device)
78
+
79
+ Mr = mel_tf(ref)
80
+ Mh = mel_tf(hat)
81
+
82
+ Ar = Mr.reshape(Mr.size(0), -1)
83
+ Ah = Mh.reshape(Mh.size(0), -1)
84
+
85
+ sim = F.cosine_similarity(Ar, Ah, dim=-1)
86
+ return float(sim.mean().item())
87
+
88
+
89
+ @torch.no_grad()
90
+ def drms_avg_db_stereo(ref, hat, win_length=1024, hop_length=256):
91
+ ref = _ensure_stereo_torch(ref)
92
+ hat = _ensure_stereo_torch(hat)
93
+
94
+ def _rms_db(x):
95
+ C, T = x.size(0), x.size(1)
96
+ if T < win_length:
97
+ x = F.pad(x, (0, win_length - T))
98
+ frames = x.unfold(dimension=-1, size=win_length, step=hop_length)
99
+ rms = torch.sqrt(frames.pow(2).mean(dim=-1) + _EPS)
100
+ db = 20.0 * torch.log10(rms + _EPS)
101
+ return db
102
+
103
+ dbr = _rms_db(ref)
104
+ dbh = _rms_db(hat)
105
+
106
+ Fmin = min(dbr.size(-1), dbh.size(-1))
107
+ dbr = dbr[:, :Fmin]
108
+ dbh = dbh[:, :Fmin]
109
+
110
+ d_db = dbh - dbr
111
+ return float(d_db.mean(dim=-1).mean().item())
112
+
113
+
114
+ def load_stereo_wav_np(path):
115
+ y, sr = librosa.load(path, sr=None, mono=False)
116
+ if y.ndim == 1:
117
+ y = np.stack([y, y], axis=0)
118
+ elif y.shape[0] != 2:
119
+ y = y[:2]
120
+ return y, sr
121
+
122
+
123
+ def compute_spectrogram_np(audio_stereo,
124
+ n_fft=512,
125
+ hop_length=160,
126
+ win_length=400,
127
+ pool=4):
128
+ def _stft_abs(sig):
129
+ st = np.abs(librosa.stft(sig, n_fft=n_fft, hop_length=hop_length, win_length=win_length))
130
+ h, w = st.shape
131
+ hq, wq = h // pool, w // pool
132
+ if hq == 0 or wq == 0:
133
+ raise ValueError(f"audio too short for pooling (stft shape {st.shape})")
134
+ st = st[:hq * pool, :wq * pool]
135
+ st = st.reshape(hq, pool, wq, pool).mean(axis=(1, 3))
136
+ return st
137
+
138
+ L = np.log1p(_stft_abs(audio_stereo[0]))
139
+ if audio_stereo.shape[0] >= 2:
140
+ R = np.log1p(_stft_abs(audio_stereo[1]))
141
+ else:
142
+ R = L.copy()
143
+ spec = np.stack([L, R], axis=-1)
144
+ return spec
145
+
146
+
147
+ def render_ref_hat_panel(title, spec_ref, spec_hat, out_path, cmap="magma"):
148
+ L_all = [spec_ref[:, :, 0], spec_hat[:, :, 0]]
149
+ R_all = [spec_ref[:, :, 1], spec_hat[:, :, 1]]
150
+
151
+ if any(a.size == 0 for a in L_all + R_all):
152
+ print(f"[SKIP]")
153
+ return False
154
+
155
+ vmin_L = min(a.min() for a in L_all)
156
+ vmax_L = max(a.max() for a in L_all)
157
+ vmin_R = min(a.min() for a in R_all)
158
+ vmax_R = max(a.max() for a in R_all)
159
+
160
+ fig, axes = plt.subplots(2, 2, figsize=(8, 6), constrained_layout=True)
161
+ Lr, Rr = spec_ref[:, :, 0], spec_ref[:, :, 1]
162
+ Lh, Rh = spec_hat[:, :, 0], spec_hat[:, :, 1]
163
+
164
+ axes[0, 0].imshow(Lr, origin="lower", aspect="auto", cmap=cmap, vmin=vmin_L, vmax=vmax_L)
165
+ axes[0, 1].imshow(Lh, origin="lower", aspect="auto", cmap=cmap, vmin=vmin_L, vmax=vmax_L)
166
+ axes[1, 0].imshow(Rr, origin="lower", aspect="auto", cmap=cmap, vmin=vmin_R, vmax=vmax_R)
167
+ axes[1, 1].imshow(Rh, origin="lower", aspect="auto", cmap=cmap, vmin=vmin_R, vmax=vmax_R)
168
+
169
+ axes[0, 0].set_title("ref")
170
+ axes[0, 1].set_title("hat")
171
+ axes[0, 0].set_ylabel("Left")
172
+ axes[1, 0].set_ylabel("Right")
173
+
174
+ for ax in axes.ravel():
175
+ ax.set_xticks([])
176
+ ax.set_yticks([])
177
+
178
+ fig.suptitle(title)
179
+ os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
180
+ plt.savefig(out_path, dpi=180)
181
+ plt.close(fig)
182
+ return True
183
+
184
+
185
+ def save_ref_hat_spectrogram_panel(
186
+ ref, hat, out_path,
187
+ n_fft=512,
188
+ hop_length=160,
189
+ win_length=400,
190
+ pool=4,
191
+ title="ref vs hat (binaural spectrogram)",
192
+ cmap="magma",
193
+ ):
194
+ def _to_np_stereo(x):
195
+ if isinstance(x, torch.Tensor):
196
+ x = x.detach().to(torch.float32).cpu().numpy()
197
+ if x.ndim == 1:
198
+ x = np.stack([x, x], axis=0)
199
+ elif x.shape[0] == 1:
200
+ x = np.repeat(x, 2, axis=0)
201
+ elif x.shape[0] > 2:
202
+ x = x[:2]
203
+ return x
204
+
205
+ ref_np = _to_np_stereo(ref)
206
+ hat_np = _to_np_stereo(hat)
207
+
208
+ spec_ref = compute_spectrogram_np(ref_np, n_fft=n_fft, hop_length=hop_length, win_length=win_length, pool=pool)
209
+ spec_hat = compute_spectrogram_np(hat_np, n_fft=n_fft, hop_length=hop_length, win_length=win_length, pool=pool)
210
+ return render_ref_hat_panel(title, spec_ref, spec_hat, out_path, cmap=cmap)
eval_metrics.py ADDED
@@ -0,0 +1,1033 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc.
2
+ # All rights reserved.
3
+
4
+ import os
5
+ import json
6
+ import argparse
7
+ from pathlib import Path
8
+ from tqdm import tqdm
9
+
10
+ import torch
11
+ import torch.distributed as dist_torch
12
+ import torch.nn.functional as F
13
+ import numpy as np
14
+ from PIL import Image
15
+ import lpips
16
+ from dreamsim import dreamsim
17
+ from torchvision import transforms
18
+ from torcheval.metrics import FrechetInceptionDistance
19
+ import soundfile as sf
20
+ import resampy
21
+ import distributed as dist
22
+ import librosa
23
+ from skimage.metrics import structural_similarity as sk_ssim
24
+ from mel_scale import MelScale
25
+
26
+ # -----------------------------
27
+ # Safe, lazy import for FAD (avoid argparse conflicts from dependencies)
28
+ # -----------------------------
29
+ def safe_import_fad():
30
+ """
31
+ Import frechet_audio_distance.FrechetAudioDistance without letting downstream
32
+ libraries parse our CLI args during import time.
33
+ """
34
+ import importlib, sys
35
+ argv_backup = sys.argv[:]
36
+ try:
37
+ sys.argv = [argv_backup[0]] # hide our CLI flags from misbehaving imports
38
+ fad_mod = importlib.import_module("frechet_audio_distance")
39
+ return getattr(fad_mod, "FrechetAudioDistance")
40
+ finally:
41
+ sys.argv = argv_backup
42
+
43
+
44
+ # -----------------------------
45
+ # Distributed init
46
+ # -----------------------------
47
+ def setup_distributed():
48
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ and "LOCAL_RANK" in os.environ:
49
+ rank = int(os.environ["RANK"])
50
+ world_size = int(os.environ["WORLD_SIZE"])
51
+ local_rank = int(os.environ["LOCAL_RANK"])
52
+ else:
53
+ return 0, 1, 0
54
+
55
+ os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
56
+ os.environ.setdefault("MASTER_PORT", "29500")
57
+
58
+ assert torch.cuda.is_available(), "CUDA Unavailable"
59
+ assert torch.cuda.device_count() > local_rank, "local_rank out of the number of GPUs"
60
+ torch.cuda.set_device(local_rank)
61
+
62
+ dist_torch.init_process_group(
63
+ backend="nccl",
64
+ init_method="env://",
65
+ rank=rank,
66
+ world_size=world_size,
67
+ )
68
+ dist_torch.barrier()
69
+
70
+ if rank == 0:
71
+ print(f"[init] world_size={world_size} | rank->gpu OK")
72
+
73
+ return rank, world_size, local_rank
74
+
75
+
76
+ # -----------------------------
77
+ # Vision metrics factory
78
+ # -----------------------------
79
+ def get_loss_fn(loss_fn_type, secs, device):
80
+ if loss_fn_type == 'lpips':
81
+ general_lpips_loss_fn = lpips.LPIPS(net='alex').to(device).eval()
82
+
83
+ def loss_fn(img0_paths, img1_paths):
84
+ img0_list, img1_list = [], []
85
+ for p0, p1 in zip(img0_paths, img1_paths):
86
+ img0 = lpips.im2tensor(lpips.load_image(p0)).to(device) # [-1,1]
87
+ img1 = lpips.im2tensor(lpips.load_image(p1)).to(device)
88
+ img0_list.append(img0)
89
+ img1_list.append(img1)
90
+ all_img0 = torch.cat(img0_list, dim=0)
91
+ all_img1 = torch.cat(img1_list, dim=0)
92
+ with torch.no_grad():
93
+ dist_val = general_lpips_loss_fn.forward(all_img0, all_img1)
94
+ return dist_val.mean()
95
+
96
+ elif loss_fn_type == 'dreamsim':
97
+ dreamsim_loss_fn, preprocess = dreamsim(pretrained=True, device=device)
98
+ dreamsim_loss_fn.eval()
99
+
100
+ def loss_fn(img0_paths, img1_paths):
101
+ img0_list, img1_list = [], []
102
+ for p0, p1 in zip(img0_paths, img1_paths):
103
+ img0 = preprocess(Image.open(p0)).to(device)
104
+ img1 = preprocess(Image.open(p1)).to(device)
105
+ img0_list.append(img0)
106
+ img1_list.append(img1)
107
+ all_img0 = torch.cat(img0_list, dim=0)
108
+ all_img1 = torch.cat(img1_list, dim=0)
109
+ with torch.no_grad():
110
+ dist_val = dreamsim_loss_fn(all_img0, all_img1)
111
+ return dist_val.mean()
112
+
113
+ elif loss_fn_type == 'fid':
114
+ fid_metrics = {}
115
+ for sec in secs:
116
+ fid_metrics[sec] = FrechetInceptionDistance(feature_dim=2048).to(device)
117
+ return fid_metrics
118
+
119
+ else:
120
+ raise NotImplementedError
121
+
122
+ return loss_fn
123
+
124
+
125
+ # ===== Helpers for LSD/SSIM (reproducing AudioMetrics behavior) =====
126
+ _EPS = 1e-12
127
+
128
+ def _ensure_stereo_np(y: np.ndarray):
129
+ if y.ndim == 1:
130
+ y = np.stack([y, y], axis=0)
131
+ elif y.ndim == 2:
132
+ if y.shape[0] == 1:
133
+ y = np.concatenate([y, y], axis=0)
134
+ elif y.shape[0] > 2:
135
+ y = y[:2, :]
136
+ else:
137
+ raise ValueError("Unsupported audio array shape")
138
+ return y
139
+
140
+ def _wav_to_spectrogram(wav: np.ndarray, rate: int):
141
+ if rate == 44100:
142
+ hop_length = 441
143
+ n_fft = 2048
144
+ elif rate == 16000:
145
+ hop_length = 160
146
+ n_fft = 743
147
+ else:
148
+ raise ValueError("Bad Samplerate (expected 16000 or 44100)")
149
+
150
+ f = np.abs(librosa.stft(wav, hop_length=hop_length, n_fft=n_fft)) # [F, T]
151
+ f = np.transpose(f, (1, 0)) # [T, F]
152
+ f_torch = torch.tensor(f[None, None, ...], dtype=torch.float32) # [1,1,T,F]
153
+ return f_torch
154
+
155
+ def _lsd_from_specs(est: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
156
+ ratio = (target ** 2) / ((est + _EPS) ** 2) + _EPS
157
+ lsd = torch.log10(ratio) ** 2
158
+ lsd = torch.mean(torch.mean(lsd, dim=3) ** 0.5, dim=2)
159
+ return lsd.mean()
160
+
161
+ def _mel_lsd_ssim_single(
162
+ e_wav: np.ndarray,
163
+ g_wav: np.ndarray,
164
+ mel_tf: MelScale,
165
+ n_fft: int = 743,
166
+ hop_length: int = 160,
167
+ ) -> tuple[float, float]:
168
+ est_mag = np.abs(librosa.stft(e_wav, n_fft=n_fft, hop_length=hop_length))
169
+ ref_mag = np.abs(librosa.stft(g_wav, n_fft=n_fft, hop_length=hop_length))
170
+ est_mag_t = torch.from_numpy(est_mag).float()
171
+ ref_mag_t = torch.from_numpy(ref_mag).float()
172
+ est_mel = mel_tf(est_mag_t)
173
+ ref_mel = mel_tf(ref_mag_t)
174
+ ex_m = est_mel.transpose(0, 1).unsqueeze(0).unsqueeze(0)
175
+ gt_m = ref_mel.transpose(0, 1).unsqueeze(0).unsqueeze(0)
176
+ mel_lsd = float(_lsd_from_specs(ex_m, gt_m))
177
+ mel_ssim = float(_ssim_from_specs(ex_m, gt_m))
178
+ return mel_lsd, mel_ssim
179
+
180
+ def _to_log_specs(x: torch.Tensor) -> torch.Tensor:
181
+ return torch.log10(x + _EPS)
182
+
183
+ def _pow_p_norm(x: torch.Tensor) -> torch.Tensor:
184
+ return torch.mean(x.pow(2), dim=(2, 3))
185
+
186
+ def _energy_unify(est: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
187
+ p_est = _pow_p_norm(est)
188
+ p_tgt = _pow_p_norm(target)
189
+ scale = torch.sqrt((p_tgt + _EPS) / (p_est + _EPS))
190
+ scale = scale[..., None, None]
191
+ est_scaled = est * scale
192
+ return est_scaled, target
193
+
194
+ def _sispec_from_specs(est: torch.Tensor, target: torch.Tensor, log_domain: bool) -> torch.Tensor:
195
+ if log_domain:
196
+ est = _to_log_specs(est)
197
+ target = _to_log_specs(target)
198
+ est_u, tgt_u = _energy_unify(est, target)
199
+ noise = est_u - tgt_u
200
+ snr = ( _pow_p_norm(tgt_u) / (_pow_p_norm(noise) + _EPS) ) + _EPS
201
+ sp_loss = 10.0 * torch.log10(snr)
202
+ return sp_loss.mean()
203
+
204
+
205
+ # ===== Image PSNR (RGB on [0,1]) =====
206
+ def _psnr_from_tensors(gt: torch.Tensor, pred: torch.Tensor, data_range: float = 1.0, eps: float = 1e-10) -> torch.Tensor:
207
+ mse = torch.mean((gt - pred) ** 2, dim=(1, 2, 3))
208
+ dr = torch.as_tensor(data_range, device=gt.device, dtype=gt.dtype)
209
+ psnr = 10.0 * torch.log10((dr * dr) / (mse + eps))
210
+ return psnr
211
+
212
+ def _ssim_from_specs(est: torch.Tensor, target: torch.Tensor) -> float:
213
+ if est.is_cuda:
214
+ est_np = est.detach().cpu().numpy()
215
+ tgt_np = target.detach().cpu().numpy()
216
+ else:
217
+ est_np = est.numpy()
218
+ tgt_np = target.numpy()
219
+
220
+ N, C, _, _ = est_np.shape
221
+ acc, cnt = 0.0, 0
222
+ for n in range(N):
223
+ for c in range(C):
224
+ ref = tgt_np[n, c, ...]
225
+ out = est_np[n, c, ...]
226
+ rng = float(out.max() - out.min())
227
+ rng = 1.0 if rng == 0.0 else rng
228
+ s = sk_ssim(out, ref, win_size=7, data_range=rng)
229
+ acc += float(s); cnt += 1
230
+ return acc / max(cnt, 1)
231
+
232
+
233
+ # ==========================================================
234
+ # Streaming, DDP-friendly Audio FAD
235
+ # (embeddings identical to official FrechetAudioDistance)
236
+ # ==========================================================
237
+ class _RunningGaussianStats:
238
+ def __init__(self, feat_dim: int, device: torch.device):
239
+ self.D = feat_dim
240
+ self.device = device
241
+ self.reset()
242
+
243
+ def reset(self):
244
+ D = self.D
245
+ self.count = torch.zeros(1, device=self.device, dtype=torch.float64)
246
+ self.sum_feat = torch.zeros(D, device=self.device, dtype=torch.float64)
247
+ self.sum_outer = torch.zeros(D, D, device=self.device, dtype=torch.float64)
248
+
249
+ @torch.no_grad()
250
+ def update(self, feats: torch.Tensor): # [N, D]
251
+ if feats is None or feats.numel() == 0:
252
+ return
253
+ f = feats.to(dtype=torch.float64)
254
+ self.count += torch.tensor([f.shape[0]], device=self.device, dtype=torch.float64)
255
+ self.sum_feat += f.sum(dim=0)
256
+ self.sum_outer += f.t().mm(f)
257
+
258
+ @torch.no_grad()
259
+ def sync(self):
260
+ if dist_torch.is_initialized():
261
+ for t in (self.count, self.sum_feat, self.sum_outer):
262
+ dist_torch.all_reduce(t, op=dist_torch.ReduceOp.SUM)
263
+
264
+ @torch.no_grad()
265
+ def mean_cov(self, eps: float = 1e-6):
266
+ n = int(self.count.item())
267
+ if n == 0:
268
+ return None, None
269
+ mean = self.sum_feat / self.count
270
+ cov = self.sum_outer / self.count - torch.ger(mean, mean)
271
+ cov = cov + torch.eye(self.D, device=self.device, dtype=torch.float64) * eps
272
+ return mean, cov
273
+
274
+
275
+ @torch.no_grad()
276
+ def _frechet_distance_torch(mean1, cov1, mean2, cov2) -> float:
277
+ diff = mean1 - mean2
278
+ diff2 = diff.dot(diff)
279
+ evals1, evecs1 = torch.linalg.eigh(cov1)
280
+ sqrt1 = evecs1 @ torch.diag(evals1.clamp(min=0).sqrt()) @ evecs1.t()
281
+ prod = sqrt1 @ cov2 @ sqrt1
282
+ evals_prod = torch.linalg.eigvalsh(prod).clamp(min=0).sqrt()
283
+ trace = torch.trace(cov1 + cov2) - 2.0 * evals_prod.sum()
284
+ return float((diff2 + trace).item())
285
+
286
+
287
+ class StreamingFAD:
288
+ """
289
+ Mono (downmix) FID-style streaming FAD:
290
+ - update_from_wavs(paths, is_real=True/False)
291
+ - compute() # does DDP all_reduce internally
292
+ """
293
+ def __init__(self, fad_backend, pad_seconds: float = 0.96, batch_size: int = 16):
294
+ self.fad = fad_backend
295
+ self.device = self.fad.device
296
+ self.bs = batch_size
297
+ self.pad_len = int(round(self.fad.sample_rate * float(pad_seconds)))
298
+ self.feat_dim = self._infer_feat_dim()
299
+ self.real_stats = _RunningGaussianStats(self.feat_dim, self.device)
300
+ self.fake_stats = _RunningGaussianStats(self.feat_dim, self.device)
301
+
302
+ def _infer_feat_dim(self) -> int:
303
+ sr = self.fad.sample_rate
304
+ x = np.zeros((self.pad_len,), dtype=np.float32)
305
+ emb = self.fad.get_embeddings([x], sr=sr)
306
+ return int(emb.shape[-1]) if isinstance(emb, np.ndarray) else int(emb.shape[-1])
307
+
308
+ @torch.no_grad()
309
+ def _load_and_resample(self, path: str):
310
+ try:
311
+ audio, sr = sf.read(path, dtype="float32", always_2d=False)
312
+ except Exception as e:
313
+ print(f"[StreamingFAD] read error: {path}: {e}")
314
+ return None
315
+ if audio is None or (isinstance(audio, np.ndarray) and audio.size == 0):
316
+ return None
317
+ if isinstance(audio, np.ndarray) and audio.ndim == 2:
318
+ audio = audio.mean(axis=1)
319
+ if sr != self.fad.sample_rate:
320
+ try:
321
+ audio = resampy.resample(audio, sr, self.fad.sample_rate)
322
+ except Exception as e:
323
+ print(f"[StreamingFAD] resample error: {path}: {e}")
324
+ return None
325
+ if audio.shape[0] < self.pad_len:
326
+ pad = np.zeros((self.pad_len - audio.shape[0],), dtype=np.float32)
327
+ audio = np.concatenate([audio, pad], axis=0)
328
+ return audio.astype(np.float32, copy=False)
329
+
330
+ @torch.no_grad()
331
+ def update_from_wavs(self, wav_paths, is_real: bool):
332
+ if not wav_paths:
333
+ return
334
+ xs = []
335
+ for p in wav_paths:
336
+ a = self._load_and_resample(p)
337
+ if a is not None:
338
+ xs.append(a)
339
+ if not xs:
340
+ return
341
+ feats_chunks = []
342
+ for i in range(0, len(xs), self.bs):
343
+ chunk = xs[i:i+self.bs]
344
+ emb_np = self.fad.get_embeddings(chunk, sr=self.fad.sample_rate)
345
+ if isinstance(emb_np, np.ndarray):
346
+ if emb_np.size == 0:
347
+ continue
348
+ feats_chunks.append(torch.from_numpy(emb_np).to(self.device))
349
+ else:
350
+ if emb_np.numel() == 0:
351
+ continue
352
+ feats_chunks.append(emb_np.to(self.device))
353
+ if len(feats_chunks) == 0:
354
+ return
355
+ feats = torch.cat(feats_chunks, dim=0)
356
+ (self.real_stats if is_real else self.fake_stats).update(feats)
357
+
358
+ @torch.no_grad()
359
+ def compute(self) -> float:
360
+ self.real_stats.sync()
361
+ self.fake_stats.sync()
362
+ m1, c1 = self.real_stats.mean_cov()
363
+ m2, c2 = self.fake_stats.mean_cov()
364
+ if (m1 is None) or (m2 is None):
365
+ raise RuntimeError("StreamingFAD: empty stats")
366
+ return _frechet_distance_torch(m1, c1, m2, c2)
367
+
368
+
369
+ class StereoStreamingFAD:
370
+ def __init__(self, fad_backend, pad_seconds: float = 0.96, batch_size: int = 16):
371
+ self.fad = fad_backend
372
+ self.device = self.fad.device
373
+ self.bs = batch_size
374
+ self.pad_len = int(round(self.fad.sample_rate * float(pad_seconds)))
375
+
376
+ self.feat_dim = self._infer_feat_dim()
377
+ self.L_real = _RunningGaussianStats(self.feat_dim, self.device)
378
+ self.L_fake = _RunningGaussianStats(self.feat_dim, self.device)
379
+ self.R_real = _RunningGaussianStats(self.feat_dim, self.device)
380
+ self.R_fake = _RunningGaussianStats(self.feat_dim, self.device)
381
+
382
+ def _infer_feat_dim(self) -> int:
383
+ sr = self.fad.sample_rate
384
+ x = np.zeros((self.pad_len,), dtype=np.float32)
385
+ emb = self.fad.get_embeddings([x], sr=sr)
386
+ return int(emb.shape[-1]) if isinstance(emb, np.ndarray) else int(emb.shape[-1])
387
+
388
+ @torch.no_grad()
389
+ def _load_lr_and_resample_pad(self, path: str):
390
+ try:
391
+ audio, sr = sf.read(path, dtype="float32", always_2d=True) # [T, C]
392
+ except Exception as e:
393
+ print(f"[StereoFAD] read error: {path}: {e}")
394
+ return None, None
395
+ if audio is None or audio.size == 0:
396
+ return None, None
397
+
398
+ C = audio.shape[1]
399
+ if C == 1:
400
+ L = audio[:, 0]; R = audio[:, 0]
401
+ else:
402
+ L = audio[:, 0]; R = audio[:, 1] if C >= 2 else audio[:, 0]
403
+
404
+ if sr != self.fad.sample_rate:
405
+ try:
406
+ L = resampy.resample(L, sr, self.fad.sample_rate)
407
+ R = resampy.resample(R, sr, self.fad.sample_rate)
408
+ except Exception as e:
409
+ print(f"[StereoFAD] resample error: {path}: {e}")
410
+ return None, None
411
+
412
+ def _pad_to_len(x: np.ndarray, n: int):
413
+ if x.shape[0] >= n:
414
+ return x.astype(np.float32, copy=False)
415
+ pad = np.zeros((n - x.shape[0],), dtype=np.float32)
416
+ return np.concatenate([x, pad], axis=0)
417
+
418
+ L = _pad_to_len(L, self.pad_len)
419
+ R = _pad_to_len(R, self.pad_len)
420
+ return L, R
421
+
422
+ @torch.no_grad()
423
+ def update_from_wavs(self, wav_paths, is_real: bool):
424
+ if not wav_paths:
425
+ return
426
+ L_list, R_list = [], []
427
+ for p in wav_paths:
428
+ L, R = self._load_lr_and_resample_pad(p)
429
+ if L is not None and R is not None:
430
+ L_list.append(L); R_list.append(R)
431
+ if not L_list:
432
+ return
433
+
434
+ def _embed_and_update(xs, stats_obj: _RunningGaussianStats):
435
+ feats_chunks = []
436
+ for i in range(0, len(xs), self.bs):
437
+ chunk = xs[i:i+self.bs]
438
+ emb_np = self.fad.get_embeddings(chunk, sr=self.fad.sample_rate)
439
+ if isinstance(emb_np, np.ndarray):
440
+ if emb_np.size == 0:
441
+ continue
442
+ feats_chunks.append(torch.from_numpy(emb_np).to(self.device))
443
+ else:
444
+ if emb_np.numel() == 0:
445
+ continue
446
+ feats_chunks.append(emb_np.to(self.device))
447
+ if len(feats_chunks) == 0:
448
+ return
449
+ feats = torch.cat(feats_chunks, dim=0)
450
+ stats_obj.update(feats)
451
+
452
+ if is_real:
453
+ _embed_and_update(L_list, self.L_real)
454
+ _embed_and_update(R_list, self.R_real)
455
+ else:
456
+ _embed_and_update(L_list, self.L_fake)
457
+ _embed_and_update(R_list, self.R_fake)
458
+
459
+ @torch.no_grad()
460
+ def compute(self):
461
+ for t in (self.L_real, self.L_fake, self.R_real, self.R_fake):
462
+ t.sync()
463
+ mL_r, cL_r = self.L_real.mean_cov()
464
+ mL_f, cL_f = self.L_fake.mean_cov()
465
+ mR_r, cR_r = self.R_real.mean_cov()
466
+ mR_f, cR_f = self.R_fake.mean_cov()
467
+ if (mL_r is None) or (mL_f is None) or (mR_r is None) or (mR_f is None):
468
+ raise RuntimeError("StereoStreamingFAD: empty stats")
469
+
470
+ fad_left = _frechet_distance_torch(mL_r, cL_r, mL_f, cL_f)
471
+ fad_right = _frechet_distance_torch(mR_r, cR_r, mR_f, cR_f)
472
+ fad_mean = 0.5 * (fad_left + fad_right)
473
+ return float(fad_left), float(fad_right), float(fad_mean)
474
+
475
+
476
+ # -----------------------------
477
+ # Stereo-friendly Audio Metrics (LSD/SSIM/MelCos/DRMS)
478
+ # -----------------------------
479
+ def _load_librosa_stereo(path: str, sr: int) -> np.ndarray:
480
+ y, _ = librosa.load(path, sr=sr, mono=False)
481
+ y = _ensure_stereo_np(y) # (2, T)
482
+ return y
483
+
484
+ def _mel_cosine_single_channel(wav: np.ndarray, ref: np.ndarray, sr: int, mel_tf: MelScale) -> float:
485
+ hop_length = 160; n_fft = 743
486
+ est_mag = np.abs(librosa.stft(wav, hop_length=hop_length, n_fft=n_fft)) # [F, T]
487
+ ref_mag = np.abs(librosa.stft(ref, hop_length=hop_length, n_fft=n_fft))
488
+
489
+ est_mag_t = torch.tensor(est_mag, dtype=torch.float32) # [F,T]
490
+ ref_mag_t = torch.tensor(ref_mag, dtype=torch.float32) # [F,T]
491
+
492
+ est_mel = mel_tf(est_mag_t) # [80, T]
493
+ ref_mel = mel_tf(ref_mag_t) # [80, T]
494
+
495
+ sim = F.cosine_similarity(est_mel.flatten(), ref_mel.flatten(), dim=0)
496
+ return float(sim.item())
497
+
498
+ # -----------------------------
499
+ # Evaluate
500
+ # -----------------------------
501
+ def evaluate(args, dataset_name, eval_type, metric_logger, loss_fns,
502
+ gt_dir, exp_dir, secs, device, rank, world_size, modals):
503
+
504
+ lpips_loss_fn, dreamsim_loss_fn, fid_loss_fn = loss_fns
505
+
506
+ if eval_type == 'rollout':
507
+ eval_name = 'rollout'
508
+ image_idxs = secs.copy()
509
+ elif eval_type == 'time':
510
+ eval_name = eval_type
511
+ image_idxs = secs.copy()
512
+ else:
513
+ raise ValueError(f"Unknown eval_type {eval_type}")
514
+
515
+ if 'v' in modals:
516
+ for s in secs:
517
+ metric_logger.meters[f'{dataset_name}_{eval_name}_fid_{int(s)}'].update(0.0, n=0)
518
+
519
+ # Episodes split by rank
520
+ all_eps = sorted([e for e in os.listdir(gt_dir) if os.path.isdir(os.path.join(gt_dir, e))])
521
+ eps = all_eps[rank::world_size]
522
+ if len(eps) == 0:
523
+ return
524
+
525
+ to_tensor = transforms.ToTensor()
526
+
527
+ fad_streams = {}
528
+ stereo_mode = False
529
+ if 'a' in modals:
530
+ try:
531
+ FADLib = safe_import_fad()
532
+ except Exception as e:
533
+ if rank == 0:
534
+ print(f"[WARN] Fail to import frechet_audio_distance:{e}")
535
+ FADLib = None
536
+
537
+ if FADLib is not None:
538
+ base_fad = FADLib(
539
+ model_name=args.fad_model,
540
+ sample_rate=args.fad_sr,
541
+ verbose=False
542
+ )
543
+ if args.fad_model == 'vggish' and not args.mono:
544
+ stereo_mode = True
545
+ for sec in secs:
546
+ fad_streams[sec] = StereoStreamingFAD(base_fad, pad_seconds=args.fad_pad_sec, batch_size=16)
547
+ else:
548
+ for sec in secs:
549
+ fad_streams[sec] = StreamingFAD(base_fad, pad_seconds=args.fad_pad_sec, batch_size=16)
550
+
551
+ mel_tf = MelScale(n_mels=80, sample_rate=16000, n_stft=372)
552
+
553
+ for batch_start in tqdm(range(0, len(eps), args.batch_size),
554
+ total=(len(eps) + args.batch_size - 1) // args.batch_size,
555
+ disable=(rank != 0)):
556
+ batch_eps = eps[batch_start:batch_start + args.batch_size]
557
+
558
+ # per-sec containers (vision)
559
+ gt_img_batch, exp_img_batch = {}, {}
560
+ gt_img_paths_batch, exp_img_paths_batch = {}, {}
561
+ denorm_pairs_by_sec = {}
562
+ secs_py = [int(s) for s in secs]
563
+ denorm_pairs_by_sec = {s: [] for s in secs_py}
564
+ for sec in secs:
565
+ gt_img_batch[sec], exp_img_batch[sec] = [], []
566
+ gt_img_paths_batch[sec], exp_img_paths_batch[sec] = [], []
567
+
568
+ # per-sec containers (audio paths)
569
+ gt_wav_paths_batch, exp_wav_paths_batch = {}, {}
570
+ for sec in secs:
571
+ gt_wav_paths_batch[sec], exp_wav_paths_batch[sec] = [], []
572
+
573
+ for ep in batch_eps:
574
+ gt_ep_dir = os.path.join(gt_dir, ep)
575
+ exp_ep_dir = os.path.join(exp_dir, ep)
576
+
577
+ if (not os.path.isdir(gt_ep_dir)) or (not os.path.isdir(exp_ep_dir)):
578
+ continue
579
+
580
+ gt_dist_p = os.path.join(gt_ep_dir, "distance.json")
581
+ exp_dist_p = os.path.join(exp_ep_dir, "distance.json")
582
+ try:
583
+ if os.path.isfile(gt_dist_p) and os.path.isfile(exp_dist_p):
584
+ with open(gt_dist_p, "r") as f: gt_list = json.load(f)
585
+ with open(exp_dist_p, "r") as f: exp_list = json.load(f)
586
+ gt_map = {int(it["sec"]): float(it["denorm_gt"]) for it in gt_list if "sec" in it and "denorm_gt" in it}
587
+ exp_map = {int(it["sec"]): float(it["denorm_pred"]) for it in exp_list if "sec" in it and "denorm_pred" in it}
588
+ for s in secs_py:
589
+ if s in gt_map and s in exp_map:
590
+ denorm_pairs_by_sec[s].append((gt_map[s], exp_map[s]))
591
+ except Exception:
592
+ pass
593
+
594
+
595
+ for sec, image_idx in zip(secs, image_idxs):
596
+ # ---- vision
597
+ if 'v' in modals:
598
+ gt_sec_img_path = os.path.join(gt_ep_dir, f'{int(image_idx)}.png')
599
+ exp_sec_img_path = os.path.join(exp_ep_dir, f'{int(image_idx)}.png')
600
+ if os.path.isfile(gt_sec_img_path) and os.path.isfile(exp_sec_img_path):
601
+ try:
602
+ gt_img = to_tensor(Image.open(gt_sec_img_path).convert("RGB")).unsqueeze(0).to(device)
603
+ exp_img = to_tensor(Image.open(exp_sec_img_path).convert("RGB")).unsqueeze(0).to(device)
604
+ if torch.isfinite(gt_img).all() and torch.isfinite(exp_img).all():
605
+ gt_img_batch[sec].append(gt_img)
606
+ exp_img_batch[sec].append(exp_img)
607
+ gt_img_paths_batch[sec].append(gt_sec_img_path)
608
+ exp_img_paths_batch[sec].append(exp_sec_img_path)
609
+ except Exception:
610
+ pass
611
+
612
+ # ---- audio
613
+ if 'a' in modals:
614
+ gt_sec_wav_path = os.path.join(gt_ep_dir, f'{int(image_idx)}.wav')
615
+ exp_sec_wav_path = os.path.join(exp_ep_dir, f'{int(image_idx)}.wav')
616
+ if os.path.isfile(gt_sec_wav_path) and os.path.isfile(exp_sec_wav_path):
617
+ gt_wav_paths_batch[sec].append(gt_sec_wav_path)
618
+ exp_wav_paths_batch[sec].append(exp_sec_wav_path)
619
+
620
+ # ---- vision metric update per batch
621
+ if 'v' in modals:
622
+ for sec in secs:
623
+ if (len(gt_img_batch[sec]) == 0) or (len(exp_img_batch[sec]) == 0):
624
+ continue
625
+ lpips_dists = lpips_loss_fn(gt_img_paths_batch[sec], exp_img_paths_batch[sec])
626
+ dreamsim_dists = dreamsim_loss_fn(gt_img_paths_batch[sec], exp_img_paths_batch[sec])
627
+ metric_logger.meters[f'{dataset_name}_{eval_name}_lpips_{sec}'].update(lpips_dists, n=1)
628
+ metric_logger.meters[f'{dataset_name}_{eval_name}_dreamsim_{sec}'].update(dreamsim_dists, n=1)
629
+
630
+ sec_gt_batch = torch.cat(gt_img_batch[sec], dim=0)
631
+ sec_exp_batch = torch.cat(exp_img_batch[sec], dim=0)
632
+ if torch.isfinite(sec_gt_batch).all() and torch.isfinite(sec_exp_batch).all():
633
+ fid_loss_fn[sec].update(images=sec_gt_batch, is_real=True)
634
+ fid_loss_fn[sec].update(images=sec_exp_batch, is_real=False)
635
+ psnr_vals = _psnr_from_tensors(sec_gt_batch, sec_exp_batch, data_range=1.0) # (N,)
636
+ metric_logger.meters[f'{dataset_name}_{eval_name}_psnr_{sec}'].update(psnr_vals.mean(), n=1)
637
+
638
+ # ---- audio metrics per batch
639
+ if 'a' in modals:
640
+ # FAD (streaming)
641
+ if len(fad_streams) > 0:
642
+ for sec in secs:
643
+ if len(gt_wav_paths_batch[sec]) == 0 and len(exp_wav_paths_batch[sec]) == 0:
644
+ continue
645
+ fad_streams[sec].update_from_wavs(gt_wav_paths_batch[sec], is_real=True)
646
+ fad_streams[sec].update_from_wavs(exp_wav_paths_batch[sec], is_real=False)
647
+
648
+ # LSD / SSIM / MelCos / dRMS-db
649
+ _AUDIO_SR = 16000
650
+ for sec in secs:
651
+ gt_list = gt_wav_paths_batch[sec]
652
+ exp_list = exp_wav_paths_batch[sec]
653
+ if len(gt_list) == 0 or len(exp_list) == 0:
654
+ continue
655
+ pair_cnt = min(len(gt_list), len(exp_list))
656
+ if pair_cnt == 0:
657
+ continue
658
+
659
+ lsd_L, lsd_R, ssim_L, ssim_R = [], [], [], []
660
+ mel_L, mel_R = [], []
661
+
662
+ mel_lsd_L, mel_lsd_R = [], []
663
+ mel_ssim_L, mel_ssim_R = [], []
664
+
665
+ sispec_nl_L, sispec_nl_R = [], []
666
+ sispec_log_L, sispec_log_R = [], []
667
+ mel_sispec_nl_L, mel_sispec_n_R = [], []
668
+ mel_sispec_log_L, mel_sispec_log_R = [], []
669
+
670
+
671
+ for i in range(pair_cnt):
672
+ gpath = gt_list[i]
673
+ epath = exp_list[i]
674
+ try:
675
+ g_st = _load_librosa_stereo(gpath, _AUDIO_SR) # (2,T)
676
+ e_st = _load_librosa_stereo(epath, _AUDIO_SR) # (2,T)
677
+
678
+ if args.mono:
679
+ g_mono = g_st.mean(axis=0)
680
+ e_mono = e_st.mean(axis=0)
681
+
682
+ # LSD/SSIM
683
+ gt_sp = _wav_to_spectrogram(g_mono, rate=_AUDIO_SR)
684
+ ex_sp = _wav_to_spectrogram(e_mono, rate=_AUDIO_SR)
685
+ lsd_val = _lsd_from_specs(ex_sp.clone(), gt_sp.clone())
686
+ ssim_val = _ssim_from_specs(ex_sp.clone(), gt_sp.clone())
687
+
688
+ # MelCos
689
+ mel_val = _mel_cosine_single_channel(e_mono, g_mono, _AUDIO_SR, mel_tf)
690
+
691
+ # mel_lsd & mel_ssim
692
+ mel_lsd_val, mel_ssim_val = _mel_lsd_ssim_single(e_mono, g_mono, mel_tf)
693
+
694
+ # sispec
695
+ sispec_nl = _sispec_from_specs(ex_sp.clone(), gt_sp.clone(), log_domain=False)
696
+ sispec_log = _sispec_from_specs(ex_sp.clone(), gt_sp.clone(), log_domain=True)
697
+ # Mel sispec
698
+ mel_sispec_nl = _sispec_from_specs(ex_m.clone(), gt_m.clone(), log_domain=False)
699
+ mel_sispec_log = _sispec_from_specs(ex_m.clone(), gt_m.clone(), log_domain=True)
700
+
701
+ metric_logger.meters[f'{dataset_name}_{eval_name}_lsd_{sec}'].update(lsd_val, n=1)
702
+ metric_logger.meters[f'{dataset_name}_{eval_name}_ssim_{sec}'].update(
703
+ torch.tensor(ssim_val), n=1
704
+ )
705
+ metric_logger.meters[f'{dataset_name}_{eval_name}_melcos_{sec}'].update(
706
+ torch.tensor(mel_val), n=1
707
+ )
708
+
709
+ metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_lsd_{sec}'].update(
710
+ torch.tensor(float(mel_lsd_val)), n=1
711
+ )
712
+ metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_ssim_{sec}'].update(
713
+ torch.tensor(float(mel_ssim_val)), n=1
714
+ )
715
+
716
+ metric_logger.meters[f'{dataset_name}_{eval_name}_non_log_sispec_{sec}'].update(
717
+ torch.tensor(float(sispec_nl)), n=1
718
+ )
719
+ metric_logger.meters[f'{dataset_name}_{eval_name}_sispec_{sec}'].update(
720
+ torch.tensor(float(sispec_log)), n=1
721
+ )
722
+ metric_logger.meters[f'{dataset_name}_{eval_name}_final_non_log_mel_sispec_{sec}'].update(
723
+ torch.tensor(float(mel_sispec_nl)), n=1
724
+ )
725
+ metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_sispec_{sec}'].update(
726
+ torch.tensor(float(mel_sispec_log)), n=1
727
+ )
728
+
729
+
730
+ else:
731
+ for ch, (acc_lsd, acc_ssim, acc_mel,
732
+ acc_mel_lsd, acc_mel_ssim,
733
+ acc_sispec_nl, acc_sispec_log,
734
+ acc_mel_sispec_nl, acc_mel_sispec_log) in enumerate([
735
+ (lsd_L, ssim_L, mel_L, mel_lsd_L, mel_ssim_L, sispec_nl_L, sispec_log_L, mel_sispec_nl_L, mel_sispec_log_L),
736
+ (lsd_R, ssim_R, mel_R, mel_lsd_R, mel_ssim_R, sispec_nl_R, sispec_log_R, mel_sispec_n_R, mel_sispec_log_R),
737
+ ]):
738
+ g = g_st[ch]; e = e_st[ch]
739
+ # LSD/SSIM
740
+ gt_sp = _wav_to_spectrogram(g, rate=_AUDIO_SR)
741
+ ex_sp = _wav_to_spectrogram(e, rate=_AUDIO_SR)
742
+ acc_lsd.append(float(_lsd_from_specs(ex_sp.clone(), gt_sp.clone())))
743
+ acc_ssim.append(float(_ssim_from_specs(ex_sp.clone(), gt_sp.clone())))
744
+ # MelCos
745
+ acc_mel.append(_mel_cosine_single_channel(e, g, _AUDIO_SR, mel_tf))
746
+
747
+ # mel_lsd & mel_ssim
748
+ mel_lsd_val, mel_ssim_val = _mel_lsd_ssim_single(e, g, mel_tf)
749
+ acc_mel_lsd.append(mel_lsd_val)
750
+ acc_mel_ssim.append(mel_ssim_val)
751
+
752
+ # sispec
753
+ acc_sispec_nl.append( float(_sispec_from_specs(ex_sp.clone(), gt_sp.clone(), log_domain=False)) )
754
+ acc_sispec_log.append( float(_sispec_from_specs(ex_sp.clone(), gt_sp.clone(), log_domain=True)) )
755
+ # Mel
756
+ est_mag = np.abs(librosa.stft(e, n_fft=743, hop_length=160))
757
+ ref_mag = np.abs(librosa.stft(g, n_fft=743, hop_length=160))
758
+ est_mel = mel_tf(torch.from_numpy(est_mag).float()) # [M,T]
759
+ ref_mel = mel_tf(torch.from_numpy(ref_mag).float()) # [M,T]
760
+ ex_m = est_mel.T.unsqueeze(0).unsqueeze(0) # [1,1,T,M]
761
+ gt_m = ref_mel.T.unsqueeze(0).unsqueeze(0) # [1,1,T,M]
762
+ # sispec(Mel, non_log / log)
763
+ acc_mel_sispec_nl.append( float(_sispec_from_specs(ex_m.clone(), gt_m.clone(), log_domain=False)) )
764
+ acc_mel_sispec_log.append( float(_sispec_from_specs(ex_m.clone(), gt_m.clone(), log_domain=True)) )
765
+
766
+ except Exception:
767
+ pass
768
+
769
+ if not args.mono:
770
+ def _maybe_mean(x):
771
+ return float(np.mean(x)) if len(x) > 0 else None
772
+
773
+ v = _maybe_mean(lsd_L); w = _maybe_mean(lsd_R)
774
+ if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_lsdL_{sec}'].update(torch.tensor(v), n=1)
775
+ if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_lsdR_{sec}'].update(torch.tensor(w), n=1)
776
+ if v is not None and w is not None:
777
+ metric_logger.meters[f'{dataset_name}_{eval_name}_lsd_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
778
+
779
+ v = _maybe_mean(ssim_L); w = _maybe_mean(ssim_R)
780
+ if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_ssimL_{sec}'].update(torch.tensor(v), n=1)
781
+ if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_ssimR_{sec}'].update(torch.tensor(w), n=1)
782
+ if v is not None and w is not None:
783
+ metric_logger.meters[f'{dataset_name}_{eval_name}_ssim_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
784
+
785
+ v = _maybe_mean(mel_L); w = _maybe_mean(mel_R)
786
+ if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_melcosL_{sec}'].update(torch.tensor(v), n=1)
787
+ if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_melcosR_{sec}'].update(torch.tensor(w), n=1)
788
+ if v is not None and w is not None:
789
+ metric_logger.meters[f'{dataset_name}_{eval_name}_melcos_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
790
+
791
+ v = _maybe_mean(mel_lsd_L); w = _maybe_mean(mel_lsd_R)
792
+ if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_lsdL_{sec}'].update(torch.tensor(v), n=1)
793
+ if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_lsdR_{sec}'].update(torch.tensor(w), n=1)
794
+ if v is not None and w is not None:
795
+ metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_lsd_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
796
+
797
+ v = _maybe_mean(mel_ssim_L); w = _maybe_mean(mel_ssim_R)
798
+ if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_ssimL_{sec}'].update(torch.tensor(v), n=1)
799
+ if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_ssimR_{sec}'].update(torch.tensor(w), n=1)
800
+ if v is not None and w is not None:
801
+ metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_ssim_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
802
+
803
+ v = _maybe_mean(sispec_nl_L); w = _maybe_mean(sispec_nl_R)
804
+ if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_non_log_sispecL_{sec}'].update(torch.tensor(v), n=1)
805
+ if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_non_log_sispecR_{sec}'].update(torch.tensor(w), n=1)
806
+ if v is not None and w is not None:
807
+ metric_logger.meters[f'{dataset_name}_{eval_name}_non_log_sispec_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
808
+
809
+ v = _maybe_mean(sispec_log_L); w = _maybe_mean(sispec_log_R)
810
+ if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_sispecL_{sec}'].update(torch.tensor(v), n=1)
811
+ if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_sispecR_{sec}'].update(torch.tensor(w), n=1)
812
+ if v is not None and w is not None:
813
+ metric_logger.meters[f'{dataset_name}_{eval_name}_sispec_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
814
+
815
+ v = _maybe_mean(mel_sispec_nl_L); w = _maybe_mean(mel_sispec_n_R)
816
+ if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_final_non_log_mel_sispecL_{sec}'].update(torch.tensor(v), n=1)
817
+ if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_final_non_log_mel_sispecR_{sec}'].update(torch.tensor(w), n=1)
818
+ if v is not None and w is not None:
819
+ metric_logger.meters[f'{dataset_name}_{eval_name}_final_non_log_mel_sispec_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
820
+
821
+ v = _maybe_mean(mel_sispec_log_L); w = _maybe_mean(mel_sispec_log_R)
822
+ if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_sispecL_{sec}'].update(torch.tensor(v), n=1)
823
+ if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_sispecR_{sec}'].update(torch.tensor(w), n=1)
824
+ if v is not None and w is not None:
825
+ metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_sispec_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
826
+ for s in secs_py:
827
+ pairs = denorm_pairs_by_sec[s]
828
+ if not pairs:
829
+ continue
830
+ arr = np.asarray(pairs, dtype=np.float32)
831
+ mask = np.isfinite(arr).all(axis=1)
832
+ if not np.any(mask):
833
+ continue
834
+ se_mean = float(np.mean((arr[mask, 1] - arr[mask, 0]) ** 2))
835
+ metric_logger.meters[f'{dataset_name}_{eval_name}_denorm_mse_{s}'].update(
836
+ torch.tensor(se_mean), n=1
837
+ )
838
+
839
+ if 'v' in modals:
840
+ feature_dim = 2048
841
+ sec_list = [int(s) for s in secs]
842
+ tmp_dir = Path(os.path.join(args.exp_dir, ".fid_tmp"))
843
+ if dist_torch.is_initialized():
844
+ if dist_torch.get_rank() == 0:
845
+ tmp_dir.mkdir(parents=True, exist_ok=True)
846
+ dist_torch.barrier()
847
+ else:
848
+ tmp_dir.mkdir(parents=True, exist_ok=True)
849
+ if dist_torch.is_initialized():
850
+ my_rank = dist_torch.get_rank()
851
+ world_size = dist_torch.get_world_size()
852
+ else:
853
+ my_rank = 0
854
+ world_size = 1
855
+
856
+ for s in sec_list:
857
+ fid_m = fid_loss_fn[s]
858
+ state = {
859
+ "real_sum": fid_m.real_sum.detach().to("cpu", torch.float64),
860
+ "real_cov_sum": fid_m.real_cov_sum.detach().to("cpu", torch.float64),
861
+ "fake_sum": fid_m.fake_sum.detach().to("cpu", torch.float64),
862
+ "fake_cov_sum": fid_m.fake_cov_sum.detach().to("cpu", torch.float64),
863
+ "num_real_images": torch.tensor(int(fid_m.num_real_images.item()), dtype=torch.int64),
864
+ "num_fake_images": torch.tensor(int(fid_m.num_fake_images.item()), dtype=torch.int64),
865
+ }
866
+ out_path = tmp_dir / f"fid_sec{s}_rank{my_rank}.pt"
867
+ torch.save(state, out_path)
868
+ if dist_torch.is_initialized():
869
+ dist_torch.barrier()
870
+ if (not dist_torch.is_initialized()) or my_rank == 0:
871
+ for s in sec_list:
872
+ agg = {
873
+ "real_sum": torch.zeros(feature_dim, dtype=torch.float64),
874
+ "real_cov_sum": torch.zeros((feature_dim, feature_dim), dtype=torch.float64),
875
+ "fake_sum": torch.zeros(feature_dim, dtype=torch.float64),
876
+ "fake_cov_sum": torch.zeros((feature_dim, feature_dim), dtype=torch.float64),
877
+ "num_real_images": torch.tensor(0, dtype=torch.int64),
878
+ "num_fake_images": torch.tensor(0, dtype=torch.int64),
879
+ }
880
+ for r in range(world_size):
881
+ p = tmp_dir / f"fid_sec{s}_rank{r}.pt"
882
+ if not p.exists():
883
+ continue
884
+ st = torch.load(p, map_location="cpu")
885
+ agg["real_sum"] += st["real_sum"]
886
+ agg["real_cov_sum"] += st["real_cov_sum"]
887
+ agg["fake_sum"] += st["fake_sum"]
888
+ agg["fake_cov_sum"] += st["fake_cov_sum"]
889
+ agg["num_real_images"] += st["num_real_images"]
890
+ agg["num_fake_images"] += st["num_fake_images"]
891
+ fid_m = fid_loss_fn[s]
892
+ fid_m.real_sum = agg["real_sum"].to(fid_m.device, fid_m.real_sum.dtype)
893
+ fid_m.real_cov_sum = agg["real_cov_sum"].to(fid_m.device, fid_m.real_cov_sum.dtype)
894
+ fid_m.fake_sum = agg["fake_sum"].to(fid_m.device, fid_m.fake_sum.dtype)
895
+ fid_m.fake_cov_sum = agg["fake_cov_sum"].to(fid_m.device, fid_m.fake_cov_sum.dtype)
896
+ fid_m.num_real_images = torch.tensor(
897
+ int(agg["num_real_images"].item()), device=fid_m.device, dtype=fid_m.num_real_images.dtype
898
+ )
899
+ fid_m.num_fake_images = torch.tensor(
900
+ int(agg["num_fake_images"].item()), device=fid_m.device, dtype=fid_m.num_fake_images.dtype
901
+ )
902
+
903
+ try:
904
+ val = float(fid_m.compute().item())
905
+ metric_logger.meters[f'{dataset_name}_{eval_name}_fid_{s}'].update(val, n=1)
906
+ except Exception as e:
907
+ print(f"[WARN] FID compute failed at sec={s}: {e}")
908
+ for s in sec_list:
909
+ for r in range(world_size):
910
+ p = tmp_dir / f"fid_sec{s}_rank{r}.pt"
911
+ try:
912
+ if p.exists():
913
+ p.unlink()
914
+ except Exception:
915
+ pass
916
+ try:
917
+ tmp_dir.rmdir()
918
+ except Exception:
919
+ pass
920
+ if dist_torch.is_initialized():
921
+ dist_torch.barrier()
922
+
923
+ if 'a' in modals and len(fad_streams) > 0:
924
+ for sec in secs:
925
+ try:
926
+ if stereo_mode:
927
+ fad_L, fad_R, fad_avg = fad_streams[sec].compute()
928
+ metric_logger.meters[f'{dataset_name}_{eval_name}_fadL_{sec}'].update(fad_L, n=1)
929
+ metric_logger.meters[f'{dataset_name}_{eval_name}_fadR_{sec}'].update(fad_R, n=1)
930
+ metric_logger.meters[f'{dataset_name}_{eval_name}_fad_{sec}'].update(fad_avg, n=1)
931
+ else:
932
+ fad_val = float(fad_streams[sec].compute())
933
+ metric_logger.meters[f'{dataset_name}_{eval_name}_fad_{sec}'].update(fad_val, n=1)
934
+ except Exception as e:
935
+ if rank == 0:
936
+ print(f"[WARN] FAD compute failed at sec={sec}: {e}")
937
+ continue
938
+
939
+
940
+ # -----------------------------
941
+ # Save
942
+ # -----------------------------
943
+ def save_metric_to_disk(metric_logger, log_p, rank):
944
+ if dist_torch.is_initialized():
945
+ metric_logger.synchronize_between_processes()
946
+ if rank == 0:
947
+ log_stats = {k: float(meter.global_avg) for k, meter in metric_logger.meters.items()}
948
+ os.makedirs(os.path.dirname(log_p), exist_ok=True)
949
+ with open(log_p, 'w') as json_file:
950
+ json.dump(log_stats, json_file, indent=4)
951
+ print(f"[OK] Metrics saved to: {log_p}")
952
+
953
+
954
+ # -----------------------------
955
+ # Main
956
+ # -----------------------------
957
+ def main(args):
958
+ rank, world_size, local_rank = setup_distributed()
959
+ device = f"cuda:{local_rank}" if world_size > 1 else ("cuda" if torch.cuda.is_available() else "cpu")
960
+ torch.backends.cudnn.benchmark = True
961
+
962
+ dataset_name = args.dataset
963
+ secs = np.array([i for i in range(1, 17)], dtype=int)
964
+
965
+ # vision metrics (will only be used if 'v' in modals)
966
+ lpips_loss_fn = get_loss_fn('lpips', secs, device)
967
+ dreamsim_loss_fn = get_loss_fn('dreamsim', secs, device)
968
+ fid_metrics_vision = get_loss_fn('fid', secs, device)
969
+
970
+ try:
971
+ metric_logger = dist.MetricLogger(delimiter=" ")
972
+ if rank == 0:
973
+ print(f"Evaluating {args.eval_name} {dataset_name} | modals = {args.modals}")
974
+
975
+ time_loss_fns = (lpips_loss_fn, dreamsim_loss_fn, fid_metrics_vision)
976
+
977
+ with torch.no_grad():
978
+ evaluate(
979
+ args=args,
980
+ dataset_name=dataset_name,
981
+ eval_type=args.eval_name,
982
+ metric_logger=metric_logger,
983
+ loss_fns=time_loss_fns,
984
+ gt_dir=args.gt_dir,
985
+ exp_dir=args.exp_dir,
986
+ secs=secs,
987
+ device=device,
988
+ rank=rank,
989
+ world_size=world_size,
990
+ modals=args.modals
991
+ )
992
+
993
+ output_fn = os.path.join(args.exp_dir, f'{dataset_name}_{args.eval_name}.json')
994
+ save_metric_to_disk(metric_logger, output_fn, rank)
995
+
996
+ except Exception as e:
997
+ if rank == 0:
998
+ print(e)
999
+ finally:
1000
+ if dist_torch.is_initialized():
1001
+ dist_torch.barrier()
1002
+ dist_torch.destroy_process_group()
1003
+
1004
+
1005
+ # -----------------------------
1006
+ # CLI
1007
+ # -----------------------------
1008
+ if __name__ == "__main__":
1009
+ parser = argparse.ArgumentParser(allow_abbrev=False)
1010
+
1011
+ parser.add_argument("--batch_size", type=int, default=64, help="batch size")
1012
+ parser.add_argument("--gt_dir", type=str, required=True, help="gt directory")
1013
+ parser.add_argument("--exp_dir", type=str, required=True, help="experiment directory (also save json here)")
1014
+ parser.add_argument("--eval_name", type=str, default='time', choices=['time', 'rollout'], help="eval type")
1015
+ parser.add_argument("--dataset", type=str, required=True, help="dataset name (for metric keys & json name)")
1016
+ parser.add_argument("--modals", type=str, default="av", choices=["a", "v", "av"],
1017
+ help="a=audio only (wav), v= image only (png), av=both")
1018
+
1019
+ # FAD options
1020
+ parser.add_argument("--fad_model", type=str, default="vggish",
1021
+ choices=["vggish", "pann", "clap", "encodec"],
1022
+ help="embedding model for FAD")
1023
+ parser.add_argument("--fad_sr", type=int, default=16000,
1024
+ help="sampling rate for FAD")
1025
+
1026
+ # Stereo VGGish FAD options
1027
+ parser.add_argument("--mono", action="store_true",
1028
+ help="default as stereo, add --mono to mono")
1029
+ parser.add_argument("--fad_pad_sec", type=float, default=1.0,
1030
+ help="pad the input of VGGish to x seconds")
1031
+
1032
+ args = parser.parse_args()
1033
+ main(args)
inference_avwm.py ADDED
@@ -0,0 +1,498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ from distributed import init_distributed
7
+ import torch
8
+ torch.backends.cuda.matmul.allow_tf32 = True
9
+ torch.backends.cudnn.allow_tf32 = True
10
+
11
+ import yaml
12
+ import argparse
13
+ import os
14
+ import numpy as np
15
+
16
+ from diffusion import create_diffusion
17
+ from diffusers.models import AutoencoderKL
18
+
19
+ import misc
20
+ import distributed as dist
21
+ from models import AVCDiT_models
22
+ from datasets import EvalDataset
23
+ from PIL import Image
24
+ from soundstream import SoundStream
25
+ import torchaudio
26
+ from skimage.measure import block_reduce
27
+
28
+ import matplotlib.pyplot as plt
29
+ import librosa
30
+ import time
31
+ import warnings
32
+ warnings.filterwarnings("ignore", category=UserWarning)
33
+ from collections import defaultdict
34
+ import json
35
+
36
+ def save_image(output_file, img, unnormalize_img):
37
+ img = img.detach().cpu()
38
+ if unnormalize_img:
39
+ img = misc.unnormalize(img)
40
+
41
+ img = img * 255
42
+ img = img.byte()
43
+ image = Image.fromarray(img.permute(1, 2, 0).numpy(), mode='RGB')
44
+
45
+ image.save(output_file)
46
+
47
+ def save_audio(output_file, audio_tensor, sample_rate):
48
+ audio_tensor = audio_tensor.detach().cpu()
49
+ if audio_tensor.ndim == 1:
50
+ audio_tensor = audio_tensor.unsqueeze(0)
51
+ torchaudio.save(output_file, audio_tensor.to(torch.float32), sample_rate)
52
+
53
+ def get_dataset_eval(config, dataset_name, eval_type, predefined_index=True):
54
+ data_config = config["eval_datasets"][dataset_name]
55
+ if predefined_index:
56
+ predefined_index = f"data_splits/{dataset_name}/test/{eval_type}.pkl"
57
+ else:
58
+ predefined_index=None
59
+
60
+
61
+ dataset = EvalDataset(
62
+ data_folder=data_config["data_folder"],
63
+ data_split_folder=data_config["test"],
64
+ dataset_name=dataset_name,
65
+ image_size=config["image_size"],
66
+ min_dist_cat=config["eval_distance"]["eval_min_dist_cat"],
67
+ max_dist_cat=config["eval_distance"]["eval_max_dist_cat"],
68
+ len_traj_pred=config["eval_len_traj_pred"],
69
+ traj_stride=config["traj_stride"],
70
+ context_size=config["eval_context_size"],
71
+ normalize=config["normalize"],
72
+ transform=misc.transform,
73
+ goals_per_obs=4,
74
+ predefined_index=predefined_index,
75
+ traj_names='traj_names.txt'
76
+ )
77
+
78
+ return dataset
79
+
80
+
81
+ @torch.no_grad()
82
+ def model_forward_wrapper_v(all_models, curr_obs, curr_delta, num_timesteps, latent_size, device, num_cond, num_goals=1, rel_t=None, progress=False):
83
+ model, diffusion, vae = all_models
84
+ x = curr_obs.to(device)
85
+ y = curr_delta.to(device)
86
+
87
+ with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
88
+ B, T = x.shape[:2]
89
+
90
+ if rel_t is None:
91
+ rel_t = (torch.ones(B)* (1. / 128.)).to(device)
92
+ rel_t *= num_timesteps
93
+
94
+ x = x.flatten(0,1)
95
+ x = vae.encode(x).latent_dist.sample().mul_(0.18215).unflatten(0, (B, T))
96
+ x_cond = x[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x.shape[2], x.shape[3], x.shape[4]).flatten(0, 1)
97
+ z = torch.randn(B*num_goals, 4, latent_size, latent_size, device=device)
98
+ y = y.flatten(0, 1)
99
+ model_kwargs = dict(y=y, x_cond=x_cond, rel_t=rel_t)
100
+ samples = diffusion.p_sample_loop(
101
+ model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=progress, device=device
102
+ )
103
+ samples = vae.decode(samples / 0.18215).sample
104
+
105
+ return torch.clip(samples, -1., 1.)
106
+
107
+
108
+ @torch.no_grad()
109
+ def model_forward_wrapper_a(all_models, curr_obs, curr_delta, num_timesteps, latent_size, device, num_cond, num_goals=1, rel_t=None, progress=False):
110
+ model, diffusion, sstream = all_models
111
+ x = curr_obs.to(device)
112
+ y = curr_delta.to(device)
113
+ with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
114
+ B, T = x.shape[:2]
115
+ if rel_t is None:
116
+ rel_t = (torch.ones(B)* (1. / 128.)).to(device)
117
+ rel_t *= num_timesteps
118
+ x = x.flatten(0,1)
119
+ x = sstream.encoder(x).unflatten(0, (B, T))
120
+ x_cond = x[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x.shape[2], x.shape[3]).flatten(0, 1)
121
+ z = torch.randn(B*num_goals, 16, 181, device=device)
122
+ y = y.flatten(0, 1)
123
+ model_kwargs = dict(y=y, x_cond=x_cond, rel_t=rel_t)
124
+ samples = diffusion.p_sample_loop(
125
+ model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=progress, device=device
126
+ )
127
+ # REWARD TOKEN
128
+ patch_tok = samples[..., -1:] # [N, 64, 1]
129
+ diff_pred = patch_tok.mean(dim=1, keepdim=True) # [N, 1]
130
+ samples = samples[..., :-1]
131
+ # AUDIO TOKENS
132
+ quantized, _, _ = sstream.quantizer(samples.permute(0, 2, 1)) # [1, T', D]
133
+ samples = sstream.decoder(quantized.permute(0, 2, 1))
134
+ return samples, diff_pred
135
+
136
+
137
+ @torch.no_grad()
138
+ def model_forward_wrapper_av(all_models, curr_obs, curr_delta, num_timesteps, latent_size, device, num_cond, num_goals=1, rel_t=None, progress=False):
139
+ model, diffusion, vae, sstream = all_models
140
+ x_v, x_a = curr_obs
141
+ x_v = x_v.to(device)
142
+ x_a = x_a.to(device)
143
+ y = curr_delta.to(device)
144
+ with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
145
+ B, T_v = x_v.shape[:2]
146
+ B, T_a = x_a.shape[:2]
147
+
148
+ if rel_t is None:
149
+ rel_t = (torch.ones(B)* (1. / 128.)).to(device)
150
+ rel_t *= num_timesteps
151
+ x_v = x_v.flatten(0,1)
152
+ x_a = x_a.flatten(0,1)
153
+ x_v = vae.encode(x_v).latent_dist.sample().mul_(0.18215).unflatten(0, (B, T_v))
154
+ x_a = sstream.encoder(x_a).unflatten(0, (B, T_a))
155
+ x_v_cond = x_v[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x_v.shape[2], x_v.shape[3], x_v.shape[4]).flatten(0, 1)
156
+ x_a_cond = x_a[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x_a.shape[2], x_a.shape[3]).flatten(0, 1)
157
+ z_v = torch.randn(B*num_goals, 4, latent_size, latent_size, device=device)
158
+ z_a = torch.randn(B*num_goals, 16, 181, device=device) #TODO
159
+ y = y.flatten(0, 1)
160
+ model_kwargs = dict(y=y, x_v_cond=x_v_cond, x_a_cond=x_a_cond, rel_t=rel_t)
161
+ samples_v, samples_a = diffusion.p_sample_loop(
162
+ model.forward, z_v.shape, z_a.shape, z_v, z_a, clip_denoised=False, model_kwargs=model_kwargs, progress=progress, device=device
163
+ )
164
+ patch_tok = samples_a[..., -1:] # [N, 16, 1]
165
+ diff_pred = patch_tok.mean(dim=1, keepdim=True) # [N, 1]
166
+ samples_a = samples_a[..., :-1]
167
+ samples_v = vae.decode(samples_v / 0.18215).sample
168
+ quantized, _, _ = sstream.quantizer(samples_a.permute(0, 2, 1)) # [1, T', D]
169
+ samples_a = sstream.decoder(quantized.permute(0, 2, 1))
170
+ return torch.clip(samples_v, -1., 1.), samples_a, diff_pred
171
+
172
+
173
+ def generate_rollout(args, output_dir, rollout_frames, idxs, all_models, obs_av, gt_av, diffs_seq, delta, num_cond, device):
174
+ (obs_image, obs_audio, orig_obs_audio)=obs_av
175
+ (gt_image, gt_audio, orig_gt_audio)=gt_av
176
+
177
+ gt_image = gt_image[:,:rollout_frames]
178
+ gt_audio = gt_audio[:,:rollout_frames]
179
+ curr_v = obs_image.to(device)
180
+ curr_a = obs_audio.to(device)
181
+ down_resampler = torchaudio.transforms.Resample(orig_freq=48000, new_freq=16000, lowpass_filter_width=64).to(device, dtype=torch.bfloat16)
182
+ episode_records = defaultdict(list)
183
+ value_key = "denorm_gt" if args.gt else "denorm_pred"
184
+
185
+ for i in range(gt_image.shape[1]):
186
+ curr_delta = delta[:, i:i+1].to(device)
187
+
188
+ x_gt_pixels = gt_image[:, i].to(device)
189
+ x_gt_audios_orig = orig_gt_audio[:, i].to(device)
190
+ if args.gt:
191
+ visualize_preds(output_dir, idxs, i+1, x_gt_pixels, x_gt_audios_orig, 16000)
192
+ denorm_gt_vals = denorm_from_tensor(diffs_seq[:, i:i+1, :]) # [B]
193
+ idxs_1d = idxs.detach().view(-1).cpu().numpy()
194
+ for b, sample_idx in enumerate(idxs_1d):
195
+ episode_records[int(sample_idx)].append({"sec": int(i+1), "value": float(denorm_gt_vals[b])})
196
+ else:
197
+ diff_gt = diffs_seq[:, i:i+1, :].unsqueeze(1).to(device)
198
+ x_pred_pixels, x_pred_audios, diff_pred = model_forward_wrapper_av(all_models, (curr_v, curr_a), curr_delta, num_timesteps=1, latent_size=args.latent_size, device=device, num_cond=num_cond, num_goals=1)
199
+ x_pred_audios_orig = down_resampler(x_pred_audios)
200
+ curr_v = torch.cat((curr_v, x_pred_pixels.unsqueeze(1)), dim=1) # append current prediction
201
+ curr_v = curr_v[:, 1:] # remove first observation
202
+ curr_a = torch.cat((curr_a, x_pred_audios.unsqueeze(1)), dim=1) # append current prediction
203
+ curr_a = curr_a[:, 1:] # remove first observation
204
+ denorm_pred_vals = denorm_from_tensor(diff_pred) # [B]
205
+ denorm_gt_vals = denorm_from_tensor(diff_gt) # [B]
206
+ visualize_preds(output_dir, idxs, i+1, x_pred_pixels, x_pred_audios_orig, 16000)
207
+ visualize_compare(output_dir, idxs, i+1,
208
+ x_pred_pixels, x_pred_audios_orig,
209
+ x_gt_pixels, x_gt_audios_orig,
210
+ denorm_pred_vals=denorm_pred_vals,
211
+ denorm_gt_vals=denorm_gt_vals)
212
+ idxs_1d = idxs.detach().view(-1).cpu().numpy()
213
+ for b, sample_idx in enumerate(idxs_1d):
214
+ episode_records[int(sample_idx)].append({"sec": int(i+1), "value": float(denorm_pred_vals[b])})
215
+
216
+ for sample_idx, rows in episode_records.items():
217
+ rows = sorted(rows, key=lambda r: r["sec"])
218
+ sample_folder = os.path.join(output_dir, f"id_{sample_idx}")
219
+ os.makedirs(sample_folder, exist_ok=True)
220
+ out_json = os.path.join(sample_folder, "distance.json")
221
+ compact = [{ "sec": r["sec"], value_key: r["value"] } for r in rows]
222
+ with open(out_json, "w") as f:
223
+ json.dump(compact, f, indent=2)
224
+
225
+
226
+ def generate_time(args, output_dir, idxs, all_models, obs_av, gt_av, diffs_seq, delta, secs, num_cond, device):
227
+ (obs_image, obs_audio, _)=obs_av
228
+ (gt_image, _, orig_gt_audio)=gt_av
229
+ down_resampler = torchaudio.transforms.Resample(orig_freq=48000, new_freq=16000, lowpass_filter_width=64).to(device, dtype=torch.bfloat16)
230
+ episode_records = defaultdict(list) # {sample_idx: [{"sec": int, "value": float}, ...]}
231
+ value_key = "denorm_gt" if args.gt else "denorm_pred"
232
+
233
+ for sec in secs:
234
+ curr_delta = delta[:, :sec].sum(dim=1, keepdim=True)
235
+ x_gt_pixels = gt_image[:, sec-1].to(device)
236
+ x_gt_audios_orig = orig_gt_audio[:, sec-1].to(device)
237
+ if args.gt:
238
+ denorm_gt_vals = denorm_from_tensor(diffs_seq[:, :sec, :].sum(dim=1, keepdim=True)) # [B]
239
+ visualize_preds(output_dir, idxs, sec, x_gt_pixels, x_gt_audios_orig, 16000)
240
+ idxs_1d = idxs.detach().view(-1).cpu().numpy()
241
+ for b, sample_idx in enumerate(idxs_1d):
242
+ episode_records[int(sample_idx)].append({"sec": int(sec), "value": float(denorm_gt_vals[b])})
243
+ else:
244
+ diff_gt = diffs_seq[:, :sec, :].sum(dim=1, keepdim=True).to(device)
245
+
246
+ print(obs_image.shape, obs_audio.shape, curr_delta.shape, obs_image.dtype, obs_audio.dtype, curr_delta.dtype)
247
+ x_pred_pixels, x_pred_audios, diff_pred = model_forward_wrapper_av(all_models, (obs_image, obs_audio) , curr_delta, sec, args.latent_size, num_cond=num_cond, num_goals=1, device=device)
248
+ x_pred_audios_orig = down_resampler(x_pred_audios)
249
+ denorm_pred_vals = denorm_from_tensor(diff_pred) # [B]
250
+ denorm_gt_vals = denorm_from_tensor(diff_gt) # [B]
251
+
252
+ visualize_preds(output_dir, idxs, sec, x_pred_pixels, x_pred_audios_orig, 16000)
253
+ visualize_compare(output_dir, idxs, sec,
254
+ x_pred_pixels, x_pred_audios_orig,
255
+ x_gt_pixels, x_gt_audios_orig,
256
+ denorm_pred_vals=denorm_pred_vals,
257
+ denorm_gt_vals=denorm_gt_vals)
258
+ idxs_1d = idxs.detach().view(-1).cpu().numpy()
259
+ for b, sample_idx in enumerate(idxs_1d):
260
+ episode_records[int(sample_idx)].append({"sec": int(sec), "value": float(denorm_pred_vals[b])})
261
+ for sample_idx, rows in episode_records.items():
262
+ rows = sorted(rows, key=lambda r: r["sec"])
263
+ sample_folder = os.path.join(output_dir, f"id_{sample_idx}")
264
+ os.makedirs(sample_folder, exist_ok=True)
265
+ out_json = os.path.join(sample_folder, "distance.json")
266
+ compact = [{ "sec": r["sec"], value_key: r["value"] } for r in rows]
267
+ with open(out_json, "w") as f:
268
+ json.dump(compact, f, indent=2)
269
+
270
+
271
+ def visualize_preds(output_dir, idxs, sec, x_pred_pixels, x_pred_audios, sample_rate):
272
+ idxs_1d = idxs.detach().view(-1)
273
+ for batch_idx, sample_idx in enumerate(idxs_1d):
274
+ sample_idx = int(sample_idx.item())
275
+ sample_folder = os.path.join(output_dir, f'id_{sample_idx}')
276
+ os.makedirs(sample_folder, exist_ok=True)
277
+ image_file = os.path.join(sample_folder, f'{sec}.png')
278
+ save_image(image_file, x_pred_pixels[batch_idx], True)
279
+ audio_file = os.path.join(sample_folder, f'{sec}.wav')
280
+ save_audio(audio_file, x_pred_audios[batch_idx], sample_rate)
281
+
282
+ def _compute_binaural_spectrogram_np(audio_2ch: np.ndarray):
283
+ def _stft_abs(signal):
284
+ n_fft = 512
285
+ hop_length = 160
286
+ win_length = 400
287
+ stft = np.abs(librosa.stft(signal, n_fft=n_fft, hop_length=hop_length, win_length=win_length))
288
+ stft = block_reduce(stft, block_size=(4, 4), func=np.mean)
289
+ return stft
290
+ L = np.log1p(_stft_abs(audio_2ch[0]))
291
+ R = np.log1p(_stft_abs(audio_2ch[1]))
292
+ spec = np.stack([L, R], axis=-1) # (F,T,2)
293
+ return spec
294
+
295
+ def denorm_from_tensor(t: torch.Tensor, min_v=-20.0, max_v=20.0, scale=0.15) -> torch.Tensor:
296
+ x = t.detach().float().view(t.shape[0], -1)[:, 0]
297
+ n01 = (x + 1.0) / 2.0
298
+ raw = n01 * (max_v - min_v) + min_v
299
+ return raw * scale
300
+
301
+ def visualize_compare(output_dir, idxs, sec,
302
+ x_pred_pixels, x_pred_audios_orig,
303
+ x_gt_pixels, x_gt_audios_orig,
304
+ denorm_pred_vals,
305
+ denorm_gt_vals):
306
+ idxs_np = idxs.detach().view(-1).cpu().numpy()
307
+
308
+ B = x_pred_pixels.shape[0]
309
+ assert x_gt_pixels.shape[0] == B and x_pred_audios_orig.shape[0] == B and x_gt_audios_orig.shape[0] == B
310
+
311
+ for b in range(B):
312
+ sample_idx = int(idxs_np[b])
313
+ sample_folder = os.path.join(output_dir, f'id_{sample_idx}')
314
+ os.makedirs(sample_folder, exist_ok=True)
315
+ out_path = os.path.join(sample_folder, f'compare_{sec}.png')
316
+ def _tensor_to_display_img(x: torch.Tensor):
317
+ x = x.detach().cpu()
318
+ x = misc.unnormalize(x)
319
+ x = (x * 255.0).round().clamp(0, 255)
320
+ x = x.to(torch.uint8).permute(1, 2, 0)
321
+ return x.numpy()
322
+
323
+ pred_img = _tensor_to_display_img(x_pred_pixels[b])
324
+ gt_img = _tensor_to_display_img(x_gt_pixels[b])
325
+
326
+ pred_aud = x_pred_audios_orig[b].detach().cpu().float().numpy()
327
+ gt_aud = x_gt_audios_orig[b].detach().cpu().float().numpy()
328
+ pred_spec = _compute_binaural_spectrogram_np(pred_aud)
329
+ gt_spec = _compute_binaural_spectrogram_np(gt_aud)
330
+
331
+ vmin_L = min(pred_spec[:, :, 0].min(), gt_spec[:, :, 0].min())
332
+ vmax_L = max(pred_spec[:, :, 0].max(), gt_spec[:, :, 0].max())
333
+ vmin_R = min(pred_spec[:, :, 1].min(), gt_spec[:, :, 1].min())
334
+ vmax_R = max(pred_spec[:, :, 1].max(), gt_spec[:, :, 1].max())
335
+
336
+ dn_pred = float(denorm_pred_vals[b]) if denorm_pred_vals is not None else 0
337
+ dn_gt = float(denorm_gt_vals[b]) if denorm_gt_vals is not None else 0
338
+
339
+ fig, axes = plt.subplots(2, 4, figsize=(14, 6), constrained_layout=True)
340
+
341
+ axes[0, 0].imshow(pred_img); axes[0, 0].set_title('pred image'); axes[0, 0].axis('off')
342
+ axes[0, 1].imshow(gt_img); axes[0, 1].set_title('gt image'); axes[0, 1].axis('off')
343
+
344
+ axes[1, 0].axis('off')
345
+ axes[1, 1].axis('off')
346
+
347
+ im_pred_L = axes[0, 2].imshow(pred_spec[:, :, 0], origin='lower', aspect='auto', vmin=vmin_L, vmax=vmax_L)
348
+ axes[0, 2].set_title('pred spec (Left)'); axes[0, 2].set_xticks([]); axes[0, 2].set_yticks([])
349
+ im_gt_L = axes[0, 3].imshow(gt_spec[:, :, 0], origin='lower', aspect='auto', vmin=vmin_L, vmax=vmax_L)
350
+ axes[0, 3].set_title('gt spec (Left)'); axes[0, 3].set_xticks([]); axes[0, 3].set_yticks([])
351
+ im_pred_R = axes[1, 2].imshow(pred_spec[:, :, 1], origin='lower', aspect='auto', vmin=vmin_R, vmax=vmax_R)
352
+ axes[1, 2].set_title('pred spec (Right)'); axes[1, 2].set_xticks([]); axes[1, 2].set_yticks([])
353
+ im_gt_R = axes[1, 3].imshow(gt_spec[:, :, 1], origin='lower', aspect='auto', vmin=vmin_R, vmax=vmax_R)
354
+ axes[1, 3].set_title('gt spec (Right)'); axes[1, 3].set_xticks([]); axes[1, 3].set_yticks([])
355
+
356
+ fig.suptitle(
357
+ f'id={sample_idx}, sec={sec} | denorm(reward_pred)={dn_pred:.4f}, denorm(reward_gt)={dn_gt:.4f}',
358
+ fontsize=11
359
+ )
360
+ plt.savefig(out_path, dpi=180)
361
+ plt.close(fig)
362
+
363
+
364
+ @torch.no_grad()
365
+ def main(args):
366
+ _, _, device, _ = init_distributed()
367
+ print(args)
368
+ device = torch.device(device)
369
+ num_tasks = dist.get_world_size()
370
+ global_rank = dist.get_rank()
371
+ exp_eval = args.exp
372
+
373
+ # model & config setup
374
+ if args.gt:
375
+ args.save_output_dir = os.path.join(args.output_dir, 'gt')
376
+ else:
377
+ exp_name = os.path.basename(exp_eval).split('.')[0]
378
+ args.save_output_dir = os.path.join(args.output_dir, exp_name)
379
+
380
+ if args.ckp != '0100000':
381
+ args.save_output_dir = args.save_output_dir + "_%s"%(args.ckp)
382
+
383
+ os.makedirs(args.save_output_dir, exist_ok=True)
384
+
385
+ with open("config/eval_config.yaml", "r") as f:
386
+ default_config = yaml.safe_load(f)
387
+ config = default_config
388
+
389
+ with open(exp_eval, "r") as f:
390
+ user_config = yaml.safe_load(f)
391
+ config.update(user_config)
392
+
393
+ eval_len_traj_pred=config["eval_len_traj_pred"]
394
+ if args.rollout_frames==-1:
395
+ args.rollout_frames=eval_len_traj_pred
396
+ assert args.rollout_frames<=eval_len_traj_pred
397
+ latent_size = config['image_size'] // 8
398
+ args.latent_size = config['image_size'] // 8
399
+
400
+ num_cond = config['context_size']
401
+ print("loading")
402
+ model_lst = (None, None, None, None)
403
+ if not args.gt:
404
+ model = AVCDiT_models[config['model']](context_size=num_cond, input_size=latent_size, in_channels=4, mode="av")
405
+ ckp = torch.load(f'{config["results_dir"]}/{config["run_name"]}/checkpoints/{args.ckp}.pth.tar', map_location='cpu', weights_only=False)
406
+ print(model.load_state_dict(ckp["ema"], strict=True))
407
+ model.eval()
408
+ model.to(device)
409
+ model = torch.compile(model)
410
+ diffusion = create_diffusion(str(250), dual=True)
411
+ vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema").to(device)
412
+
413
+ sstream = SoundStream(C=32, D=16, n_q=8, codebook_size=1024).to(device)
414
+ sstream_path=config["tokenizer_a_path"]
415
+ sstream_checkpoint = torch.load(sstream_path, map_location=device)
416
+ sstream.load_state_dict(sstream_checkpoint["model_state"])
417
+ sstream.eval()
418
+
419
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], find_unused_parameters=False)
420
+ model_lst = (model, diffusion, vae, sstream)
421
+
422
+ # Loading Datasets
423
+ dataset_names = args.datasets.split(',')
424
+ datasets = {}
425
+
426
+ for dataset_name in dataset_names:
427
+ dataset_val = get_dataset_eval(config, dataset_name, args.eval_type, predefined_index=False)
428
+
429
+ if len(dataset_val) % num_tasks != 0:
430
+ print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
431
+ 'This will slightly alter validation results as extra duplicate entries are added to achieve '
432
+ 'equal num of samples per-process.')
433
+ sampler_val = torch.utils.data.DistributedSampler(
434
+ dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
435
+
436
+ curr_data_loader = torch.utils.data.DataLoader(
437
+ dataset_val, sampler=sampler_val,
438
+ batch_size=args.batch_size,
439
+ num_workers=args.num_workers,
440
+ pin_memory=True,
441
+ drop_last=False
442
+ )
443
+ datasets[dataset_name] = curr_data_loader
444
+
445
+ print_freq = 1
446
+ header = 'Evaluation: '
447
+ metric_logger = dist.MetricLogger(delimiter=" ")
448
+
449
+ for dataset_name in dataset_names:
450
+ dataset_save_output_dir = os.path.join(args.save_output_dir, dataset_name)
451
+ os.makedirs(dataset_save_output_dir, exist_ok=True)
452
+ curr_data_loader = datasets[dataset_name]
453
+
454
+ for data_iter_step, (idxs, obs_image, gt_image, obs_audio, gt_audio, diffs_seq, delta, orig_obs_audio, orig_gt_audio) in enumerate(metric_logger.log_every(curr_data_loader, print_freq, header)):
455
+ with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
456
+ obs_image = obs_image[:, -num_cond:].to(device)
457
+ gt_image = gt_image.to(device)
458
+ obs_audio = obs_audio[:, -num_cond:].to(device)
459
+ gt_audio = gt_audio.to(device)
460
+ orig_obs_audio = orig_obs_audio[:, -num_cond:].to(device)
461
+ orig_gt_audio = orig_gt_audio.to(device)
462
+
463
+ diffs_seq = diffs_seq.to(device)
464
+ obs_av=(obs_image, obs_audio, orig_obs_audio)
465
+ gt_av=(gt_image, gt_audio, orig_gt_audio)
466
+ if args.eval_type == 'rollout':
467
+ curr_rollout_output_dir = os.path.join(dataset_save_output_dir, f'rollout_{args.rollout_frames}frames')
468
+ os.makedirs(curr_rollout_output_dir, exist_ok=True)
469
+ generate_rollout(args, curr_rollout_output_dir, args.rollout_frames, idxs, model_lst, obs_av, gt_av, diffs_seq, delta, num_cond, device)
470
+ elif args.eval_type == 'time':
471
+ if args.time_secs != '':
472
+ secs = np.array([int(sec) for sec in args.time_secs.split(',')])
473
+ else:
474
+ secs = np.array([int(sec) for sec in range(1,args.rollout_frames+1)])
475
+ curr_time_output_dir = os.path.join(dataset_save_output_dir, 'time')
476
+ os.makedirs(curr_time_output_dir, exist_ok=True)
477
+ generate_time(args, curr_time_output_dir, idxs, model_lst, obs_av, gt_av, diffs_seq, delta, secs, num_cond, device)
478
+
479
+
480
+ if __name__ == "__main__":
481
+ parser = argparse.ArgumentParser()
482
+
483
+ parser.add_argument("--output_dir", type=str, default=None, help="output directory")
484
+ parser.add_argument("--exp", type=str, default=None, help="experiment name")
485
+ parser.add_argument("--ckp", type=str, default='0100000')
486
+ parser.add_argument("--num_sec_eval", type=int, default=5)
487
+ parser.add_argument("--input_fps", type=int, default=4)
488
+ parser.add_argument("--datasets", type=str, default=None, help="dataset name")
489
+ parser.add_argument("--num_workers", type=int, default=8, help="num workers")
490
+ parser.add_argument("--batch_size", type=int, default=16, help="batch size")
491
+ parser.add_argument("--eval_type", type=str, default=None, help="type of evaluation has to be either 'time' or 'rollout'")
492
+ # Rollout Evaluation Args
493
+ parser.add_argument("--time_secs", type=str, default='', help="") #'1,2,3,4'
494
+ parser.add_argument("--rollout_frames", type=int, default=-1, help="")
495
+ parser.add_argument("--gt", type=int, default=0, help="set to 1 to produce ground truth evaluation set")
496
+ args = parser.parse_args()
497
+
498
+ main(args)
mel_scale.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor
3
+ from typing import Optional
4
+ import math
5
+
6
+ import warnings
7
+
8
+ class MelScale(torch.nn.Module):
9
+ r"""Turn a normal STFT into a mel frequency STFT, using a conversion
10
+ matrix. This uses triangular filter banks.
11
+
12
+ User can control which device the filter bank (`fb`) is (e.g. fb.to(spec_f.device)).
13
+
14
+ Args:
15
+ n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
16
+ sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
17
+ f_min (float, optional): Minimum frequency. (Default: ``0.``)
18
+ f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
19
+ n_stft (int, optional): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`. (Default: ``201``)
20
+ norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
21
+ (area normalization). (Default: ``None``)
22
+ mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
23
+
24
+ See also:
25
+ :py:func:`torchaudio.functional.melscale_fbanks` - The function used to
26
+ generate the filter banks.
27
+ """
28
+ __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']
29
+
30
+ def __init__(self,
31
+ n_mels: int = 128,
32
+ sample_rate: int = 16000,
33
+ f_min: float = 0.,
34
+ f_max: Optional[float] = None,
35
+ n_stft: int = 201,
36
+ norm: Optional[str] = None,
37
+ mel_scale: str = "htk") -> None:
38
+ super(MelScale, self).__init__()
39
+ self.n_mels = n_mels
40
+ self.sample_rate = sample_rate
41
+ self.f_max = f_max if f_max is not None else float(sample_rate // 2)
42
+ self.f_min = f_min
43
+ self.norm = norm
44
+ self.mel_scale = mel_scale
45
+
46
+ assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)
47
+ fb = melscale_fbanks(
48
+ n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm,
49
+ self.mel_scale)
50
+ self.register_buffer('fb', fb)
51
+
52
+ def forward(self, specgram: Tensor) -> Tensor:
53
+ r"""
54
+ Args:
55
+ specgram (Tensor): A spectrogram STFT of dimension (..., freq, time).
56
+
57
+ Returns:
58
+ Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
59
+ """
60
+
61
+ # (..., time, freq) dot (freq, n_mels) -> (..., n_mels, time)
62
+ mel_specgram = torch.matmul(specgram.transpose(-1, -2), self.fb).transpose(-1, -2)
63
+
64
+ return mel_specgram
65
+
66
+ def _hz_to_mel(freq: float, mel_scale: str = "htk") -> float:
67
+ r"""Convert Hz to Mels.
68
+
69
+ Args:
70
+ freqs (float): Frequencies in Hz
71
+ mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
72
+
73
+ Returns:
74
+ mels (float): Frequency in Mels
75
+ """
76
+
77
+ if mel_scale not in ['slaney', 'htk']:
78
+ raise ValueError('mel_scale should be one of "htk" or "slaney".')
79
+
80
+ if mel_scale == "htk":
81
+ return 2595.0 * math.log10(1.0 + (freq / 700.0))
82
+
83
+ # Fill in the linear part
84
+ f_min = 0.0
85
+ f_sp = 200.0 / 3
86
+
87
+ mels = (freq - f_min) / f_sp
88
+
89
+ # Fill in the log-scale part
90
+ min_log_hz = 1000.0
91
+ min_log_mel = (min_log_hz - f_min) / f_sp
92
+ logstep = math.log(6.4) / 27.0
93
+
94
+ if freq >= min_log_hz:
95
+ mels = min_log_mel + math.log(freq / min_log_hz) / logstep
96
+
97
+ return mels
98
+
99
+ def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor:
100
+ """Convert mel bin numbers to frequencies.
101
+
102
+ Args:
103
+ mels (Tensor): Mel frequencies
104
+ mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
105
+
106
+ Returns:
107
+ freqs (Tensor): Mels converted in Hz
108
+ """
109
+
110
+ if mel_scale not in ['slaney', 'htk']:
111
+ raise ValueError('mel_scale should be one of "htk" or "slaney".')
112
+
113
+ if mel_scale == "htk":
114
+ return 700.0 * (10.0**(mels / 2595.0) - 1.0)
115
+
116
+ # Fill in the linear scale
117
+ f_min = 0.0
118
+ f_sp = 200.0 / 3
119
+ freqs = f_min + f_sp * mels
120
+
121
+ # And now the nonlinear scale
122
+ min_log_hz = 1000.0
123
+ min_log_mel = (min_log_hz - f_min) / f_sp
124
+ logstep = math.log(6.4) / 27.0
125
+
126
+ log_t = (mels >= min_log_mel)
127
+ freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))
128
+
129
+ return freqs
130
+
131
+ def _create_triangular_filterbank(
132
+ all_freqs: Tensor,
133
+ f_pts: Tensor,
134
+ ) -> Tensor:
135
+ """Create a triangular filter bank.
136
+
137
+ Args:
138
+ all_freqs (Tensor): STFT freq points of size (`n_freqs`).
139
+ f_pts (Tensor): Filter mid points of size (`n_filter`).
140
+
141
+ Returns:
142
+ fb (Tensor): The filter bank of size (`n_freqs`, `n_filter`).
143
+ """
144
+ # Adopted from Librosa
145
+ # calculate the difference between each filter mid point and each stft freq point in hertz
146
+ f_diff = f_pts[1:] - f_pts[:-1] # (n_filter + 1)
147
+ slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_filter + 2)
148
+ # create overlapping triangles
149
+ zero = torch.zeros(1)
150
+ down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_filter)
151
+ up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_filter)
152
+ fb = torch.max(zero, torch.min(down_slopes, up_slopes))
153
+
154
+ return fb
155
+
156
+ def melscale_fbanks(
157
+ n_freqs: int,
158
+ f_min: float,
159
+ f_max: float,
160
+ n_mels: int,
161
+ sample_rate: int,
162
+ norm: Optional[str] = None,
163
+ mel_scale: str = "htk",
164
+ ) -> Tensor:
165
+ r"""Create a frequency bin conversion matrix.
166
+
167
+ Note:
168
+ For the sake of the numerical compatibility with librosa, not all the coefficients
169
+ in the resulting filter bank has magnitude of 1.
170
+
171
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/mel_fbanks.png
172
+ :alt: Visualization of generated filter bank
173
+
174
+ Args:
175
+ n_freqs (int): Number of frequencies to highlight/apply
176
+ f_min (float): Minimum frequency (Hz)
177
+ f_max (float): Maximum frequency (Hz)
178
+ n_mels (int): Number of mel filterbanks
179
+ sample_rate (int): Sample rate of the audio waveform
180
+ norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
181
+ (area normalization). (Default: ``None``)
182
+ mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
183
+
184
+ Returns:
185
+ Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
186
+ meaning number of frequencies to highlight/apply to x the number of filterbanks.
187
+ Each column is a filterbank so that assuming there is a matrix A of
188
+ size (..., ``n_freqs``), the applied result would be
189
+ ``A * melscale_fbanks(A.size(-1), ...)``.
190
+
191
+ """
192
+
193
+ if norm is not None and norm != "slaney":
194
+ raise ValueError("norm must be one of None or 'slaney'")
195
+
196
+ # freq bins
197
+ all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
198
+
199
+ # calculate mel freq bins
200
+ m_min = _hz_to_mel(f_min, mel_scale=mel_scale)
201
+ m_max = _hz_to_mel(f_max, mel_scale=mel_scale)
202
+
203
+ m_pts = torch.linspace(m_min, m_max, n_mels + 2)
204
+ f_pts = _mel_to_hz(m_pts, mel_scale=mel_scale)
205
+
206
+ # create filterbank
207
+ fb = _create_triangular_filterbank(all_freqs, f_pts)
208
+
209
+ if norm is not None and norm == "slaney":
210
+ # Slaney-style mel is scaled to be approx constant energy per channel
211
+ enorm = 2.0 / (f_pts[2:n_mels + 2] - f_pts[:n_mels])
212
+ fb *= enorm.unsqueeze(0)
213
+
214
+ if (fb.max(dim=0).values == 0.).any():
215
+ warnings.warn(
216
+ "At least one mel filterbank has all zero values. "
217
+ f"The value for `n_mels` ({n_mels}) may be set too high. "
218
+ f"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
219
+ )
220
+
221
+ return fb
merge_experts.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import yaml
3
+ import argparse
4
+ from models import AVCDiT_models
5
+
6
+
7
+ def add_exact_keys(mapping, keys):
8
+ for k in keys:
9
+ mapping[k] = k
10
+
11
+
12
+ def add_mlp_block_keys(mapping, mlp_name, num_blocks):
13
+ for i in range(num_blocks):
14
+ for fc in ["fc1", "fc2"]:
15
+ for param in ["weight", "bias"]:
16
+ k = f"blocks.{i}.{mlp_name}.{fc}.{param}"
17
+ mapping[k] = k
18
+
19
+
20
+ def load_from_two_checkpoints(model, ckpt1_path, ckpt2_path, map1=None, map2=None, device='cuda'):
21
+ ckpt1 = torch.load(ckpt1_path, map_location=device, weights_only=False)
22
+ ckpt2 = torch.load(ckpt2_path, map_location=device, weights_only=False)
23
+
24
+ state1 = {k.replace('_orig_mod.', ''): v for k, v in ckpt1["ema"].items()}
25
+ state2 = {k.replace('_orig_mod.', ''): v for k, v in ckpt2["ema"].items()}
26
+
27
+ model_state = model.state_dict()
28
+
29
+ new_state = {}
30
+ source_info = {} # key: model param name, value: ckpt source name
31
+
32
+ if map1:
33
+ for k_model, k_ckpt in map1.items():
34
+ if (
35
+ k_ckpt in state1
36
+ and k_model in model_state
37
+ and state1[k_ckpt].shape == model_state[k_model].shape
38
+ ):
39
+ new_state[k_model] = state1[k_ckpt]
40
+ source_info[k_model] = "ckpt1"
41
+
42
+ if map2:
43
+ for k_model, k_ckpt in map2.items():
44
+ if (
45
+ k_ckpt in state2
46
+ and k_model in model_state
47
+ and state2[k_ckpt].shape == model_state[k_model].shape
48
+ ):
49
+ new_state[k_model] = state2[k_ckpt]
50
+ source_info[k_model] = "ckpt2"
51
+
52
+ for k_model, tensor in model_state.items():
53
+ if k_model not in new_state:
54
+ if k_model in state1 and state1[k_model].shape == tensor.shape:
55
+ new_state[k_model] = state1[k_model]
56
+ source_info[k_model] = "fallback_ckpt1"
57
+
58
+ model.load_state_dict(new_state, strict=False)
59
+ print(f"Loaded {len(new_state)} / {len(model_state)} parameters")
60
+
61
+ return new_state
62
+
63
+
64
+ def main(args):
65
+ with open(args.config, "r") as f:
66
+ config = yaml.safe_load(f)
67
+
68
+ model_name = config.get("model", "AVCDiT-B/2")
69
+ print(f"Using model: {model_name}")
70
+
71
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
72
+
73
+ model = AVCDiT_models[model_name](
74
+ context_size=4,
75
+ input_size=28,
76
+ in_channels=4,
77
+ mode="av"
78
+ ).to(device)
79
+
80
+ depth = len(model.blocks)
81
+
82
+ map1 = {}
83
+ add_exact_keys(map1, [
84
+ "pos_embed_v",
85
+ "x_embedder_v.proj.weight",
86
+ "x_embedder_v.proj.bias",
87
+ "final_layer.linear.weight",
88
+ "final_layer.linear.bias",
89
+ "final_layer.adaLN_modulation.1.weight",
90
+ "final_layer.adaLN_modulation.1.bias",
91
+ ])
92
+ add_mlp_block_keys(map1, "mlp_v", depth)
93
+
94
+ map2 = {}
95
+ add_exact_keys(map2, [
96
+ "pos_embed_a_cond",
97
+ "pos_embed_a_pred",
98
+ "x_embedder_a.weight",
99
+ "x_embedder_a.bias",
100
+ "final_layer_a.linear.weight",
101
+ "final_layer_a.linear.bias",
102
+ "final_layer_a.adaLN_modulation.1.weight",
103
+ "final_layer_a.adaLN_modulation.1.bias",
104
+ ])
105
+ add_mlp_block_keys(map2, "mlp_a", depth)
106
+
107
+ merged_state_dict = load_from_two_checkpoints(
108
+ model,
109
+ ckpt1_path=args.v_expert,
110
+ ckpt2_path=args.a_expert,
111
+ map1=map1,
112
+ map2=map2,
113
+ device=device
114
+ )
115
+
116
+ torch.save({"ema": merged_state_dict}, args.output)
117
+ print(f"Merged model saved to {args.output}")
118
+
119
+
120
+ if __name__ == "__main__":
121
+ parser = argparse.ArgumentParser()
122
+ parser.add_argument("--config", type=str, required=True)
123
+ parser.add_argument("--v_expert", type=str, required=True)
124
+ parser.add_argument("--a_expert", type=str, required=True)
125
+ parser.add_argument("--output", type=str, default="experts_merged.pth")
126
+ args = parser.parse_args()
127
+
128
+ main(args)
misc.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ import matplotlib.pyplot as plt
3
+ import torch
4
+ import numpy as np
5
+ import os
6
+ from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
7
+ from PIL import Image
8
+ from torchvision import transforms
9
+ import torchvision.transforms.functional as TF
10
+
11
+
12
+ IMAGE_ASPECT_RATIO = (4 / 3) # all images are centered cropped to a 4:3 aspect ratio in training
13
+
14
+ with open("config/data_config.yaml", "r") as f:
15
+ data_config = yaml.safe_load(f)
16
+
17
+
18
+ def get_action_torch(diffusion_output, action_stats):
19
+ ndeltas = diffusion_output
20
+ ndeltas = ndeltas.reshape(ndeltas.shape[0], -1, 2)
21
+ ndeltas = unnormalize_data(ndeltas, action_stats)
22
+ actions = torch.cumsum(ndeltas, dim=1)
23
+ return actions.to(ndeltas)
24
+
25
+ def log_viz_single(dataset_name, obs_image, goal_image, preds, deltas, loss, min_idx, actions, action_stats, plan_iter=0, output_dir='plot.png'):
26
+ '''
27
+ Visualize a single instance
28
+ actions is gt actions
29
+ '''
30
+ viz_obs_image = unnormalize(obs_image.detach().cpu())[-1] # take last img
31
+ viz_goal_image = unnormalize(goal_image.detach().cpu())
32
+ deltas = deltas.detach().cpu()
33
+ loss = loss.detach().cpu()
34
+ actions = actions.detach().cpu()
35
+ pred_actions = get_action_torch(deltas[:, :, :2], action_stats)
36
+ plot_array = plot_images_and_actions(dataset_name, viz_obs_image, viz_goal_image, pred_actions, actions, min_idx, loss=loss)
37
+
38
+ plt.imshow(plot_array)
39
+ plt.axis('off') # Hide axes for a cleaner image
40
+
41
+ # Save the plot array as a PNG file locally
42
+ plt.savefig(output_dir, format='png', dpi=300, bbox_inches='tight')
43
+
44
+ def plot_images_and_actions(dataset_name, curr_viz_obs_image, curr_viz_goal_image, curr_viz_pred_actions, curr_viz_actions, min_idx, loss):
45
+ curr_viz_obs_image = curr_viz_obs_image.permute(1, 2, 0).cpu().numpy()
46
+ curr_viz_goal_image = curr_viz_goal_image.permute(1, 2, 0).cpu().numpy()
47
+
48
+ # scale back to metric space for plotting
49
+ curr_viz_pred_actions = curr_viz_pred_actions * data_config[dataset_name]['metric_waypoint_spacing']
50
+ curr_viz_actions = curr_viz_actions * data_config[dataset_name]['metric_waypoint_spacing']
51
+
52
+ # Create the figure with three subplots
53
+ fig, axs = plt.subplots(1, 3, figsize=(9, 3))
54
+
55
+ # Plot condition image
56
+ axs[0].imshow(curr_viz_obs_image)
57
+ axs[0].set_title("Condition Image", fontsize=13)
58
+ axs[0].axis("off")
59
+
60
+ # Plot goal image
61
+ axs[1].imshow(curr_viz_goal_image)
62
+ axs[1].set_title("Goal Image", fontsize=13)
63
+ axs[1].axis("off")
64
+
65
+ colors = ['red', 'orange', 'cyan']
66
+ for i in range(1, curr_viz_pred_actions.shape[0]):
67
+ color = colors[(i - 1) % len(colors)]
68
+ label = f"Sample {i} Min Loss" if i == min_idx.item() else f"{i}"
69
+
70
+ if i != min_idx.item():
71
+ axs[2].plot(-curr_viz_pred_actions[i, :, 1], curr_viz_pred_actions[i, :, 0],
72
+ color=color, marker="o", markersize=5, label=label)
73
+ axs[2].text(-curr_viz_pred_actions[i, -1, 1],
74
+ curr_viz_pred_actions[i, -1, 0],
75
+ round(loss[i].item(), 3),
76
+ color='black',
77
+ fontsize=10,
78
+ ha='left', va='bottom') # Adjust position to avoid overlap
79
+
80
+ # Highlight the minimum loss sample
81
+ axs[2].plot(-curr_viz_pred_actions[min_idx.item(), :, 1], curr_viz_pred_actions[min_idx.item(), :, 0],
82
+ color='green', marker="o", markersize=5, label=f"{min_idx.item()}")
83
+ axs[2].text(-curr_viz_pred_actions[min_idx.item(), -1, 1],
84
+ curr_viz_pred_actions[min_idx.item(), -1, 0],
85
+ round(loss[min_idx.item()].item(), 3),
86
+ color='black',
87
+ fontsize=10,
88
+ ha='left', va='bottom') # Adjust position to avoid overlap
89
+
90
+ # Plot ground truth actions
91
+ axs[2].plot(-curr_viz_actions[:, 1], curr_viz_actions[:, 0], color='blue', marker="o", label="GT")
92
+
93
+ # Set titles and labels with larger font size
94
+ axs[2].set_title(" ", fontsize=13)
95
+ axs[2].set_xlabel("X (m)", fontsize=11)
96
+ axs[2].set_ylabel("Y (m)", fontsize=11)
97
+
98
+ # Set equal aspect ratio and adjust axis limits
99
+ axs[2].set_aspect('equal', adjustable='box')
100
+ x_min, x_max = axs[2].get_xlim()
101
+ y_min, y_max = axs[2].get_ylim()
102
+ axis_range = max(x_max - x_min, y_max - y_min) / 2
103
+ x_mid = (x_max + x_min) / 2
104
+ y_mid = (y_max + y_min) / 2
105
+ axs[2].set_xlim(x_mid - axis_range, x_mid + axis_range)
106
+ axs[2].set_ylim(y_mid - axis_range, y_mid + axis_range)
107
+
108
+ axs[2].legend(loc='lower left', fontsize=10, frameon=True, bbox_to_anchor=(0, 0))
109
+ plt.tight_layout()
110
+
111
+ canvas = FigureCanvas(fig)
112
+ canvas.draw()
113
+ plot_array = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
114
+ plot_array = plot_array.reshape(canvas.get_width_height()[::-1] + (3,))
115
+ plt.close(fig)
116
+ return plot_array
117
+
118
+
119
+ def normalize_data(data, stats):
120
+ # nomalize to [0,1]
121
+ ndata = (data - stats['min']) / (stats['max'] - stats['min'])
122
+ # normalize to [-1, 1]
123
+ ndata = ndata * 2 - 1
124
+ return ndata
125
+
126
+ def unnormalize_data(ndata, stats):
127
+ ndata = (ndata + 1) / 2
128
+ data = ndata * (stats['max'].to(ndata) - stats['min'].to(ndata)) + stats['min'].to(ndata)
129
+ return data
130
+
131
+ def get_data_path(data_folder: str, f: str, time: int, data_type: str = "image"):
132
+ data_ext = {
133
+ "image": ".jpg",
134
+ "audio": ".wav"
135
+ # add more data types here
136
+ }
137
+ return os.path.join(data_folder, f, f"{str(time)}{data_ext[data_type]}")
138
+
139
+ def yaw_rotmat(yaw: float) -> np.ndarray:
140
+ return np.array(
141
+ [
142
+ [np.cos(yaw), -np.sin(yaw), 0.0],
143
+ [np.sin(yaw), np.cos(yaw), 0.0],
144
+ [0.0, 0.0, 1.0],
145
+ ],
146
+ )
147
+
148
+ def angle_difference(theta1, theta2):
149
+ delta_theta = theta2 - theta1
150
+ delta_theta = delta_theta - 2 * np.pi * np.floor((delta_theta + np.pi) / (2 * np.pi))
151
+ return delta_theta
152
+
153
+ def get_delta_np(actions):
154
+ # append zeros to first action (unbatched)
155
+ ex_actions = np.concatenate((np.zeros((1, actions.shape[1])), actions), axis=0)
156
+ delta = ex_actions[1:] - ex_actions[:-1]
157
+
158
+ return delta
159
+
160
+ def to_local_coords(
161
+ positions: np.ndarray, curr_pos: np.ndarray, curr_yaw: float
162
+ ) -> np.ndarray:
163
+ """
164
+ Convert positions to local coordinates
165
+
166
+ Args:
167
+ positions (np.ndarray): positions to convert
168
+ curr_pos (np.ndarray): current position
169
+ curr_yaw (float): current yaw
170
+ Returns:
171
+ np.ndarray: positions in local coordinates
172
+ """
173
+ rotmat = yaw_rotmat(curr_yaw)
174
+ if positions.shape[-1] == 2:
175
+ rotmat = rotmat[:2, :2]
176
+ elif positions.shape[-1] == 3:
177
+ pass
178
+ else:
179
+ raise ValueError
180
+
181
+ return (positions - curr_pos).dot(rotmat)
182
+
183
+ def calculate_delta_yaw(unnorm_actions):
184
+ x = unnorm_actions[..., 0]
185
+ y = unnorm_actions[..., 1]
186
+
187
+ yaw = torch.atan2(y, x).unsqueeze(-1)
188
+ delta_yaw = torch.cat((torch.zeros(yaw.shape[0], 1, yaw.shape[2]).to(yaw.device), yaw), dim=1)
189
+ delta_yaw = delta_yaw[:, 1:, :] - delta_yaw[:, :-1, :]
190
+
191
+ return delta_yaw
192
+
193
+ def save_planning_pred(dataset_save_output_dir, B, idxs, obs_image, goal_image, preds, deltas, loss, gt_actions, plan_iter=0):
194
+ for batch_idx, idx in enumerate(idxs.flatten()):
195
+ sample_idx = int(idx)
196
+ sample_folder = os.path.join(dataset_save_output_dir, f'id_{sample_idx}')
197
+ os.makedirs(sample_folder, exist_ok=True)
198
+
199
+ preds_save = {
200
+ 'obs_image': obs_image[batch_idx],
201
+ 'goal_image': goal_image[batch_idx],
202
+ 'preds': preds[batch_idx],
203
+ 'deltas': deltas[batch_idx],
204
+ 'loss': loss[batch_idx],
205
+ 'gt_actions': gt_actions[batch_idx],
206
+ }
207
+ preds_file = os.path.join(sample_folder, f"preds_{plan_iter}.pth")
208
+ torch.save(preds_save, preds_file)
209
+
210
+ class CenterCropAR:
211
+ def __init__(self, ar: float = IMAGE_ASPECT_RATIO):
212
+ self.ar = ar
213
+
214
+ def __call__(self, img: Image.Image):
215
+ w, h = img.size
216
+ if w > h:
217
+ img = TF.center_crop(img, (h, int(h * self.ar)))
218
+ else:
219
+ img = TF.center_crop(img, (int(w / self.ar), w))
220
+ return img
221
+
222
+ transform = transforms.Compose([
223
+ CenterCropAR(),
224
+ transforms.Resize((224, 224)),
225
+ transforms.ToTensor(),
226
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
227
+ ])
228
+
229
+ unnormalize = transforms.Normalize(
230
+ mean=[-0.5 / 0.5, -0.5 / 0.5, -0.5 / 0.5],
231
+ std=[1 / 0.5, 1 / 0.5, 1 / 0.5]
232
+ )
models.py ADDED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # GLIDE: https://github.com/openai/glide-text2im
9
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10
+ # --------------------------------------------------------
11
+ import torch
12
+ import torch.nn as nn
13
+ import numpy as np
14
+ import math
15
+ from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
16
+
17
+
18
+ def modulate(x, shift, scale):
19
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
20
+
21
+
22
+ #################################################################################
23
+ # Embedding Layers for Timesteps and Class Labels #
24
+ #################################################################################
25
+
26
+ class TimestepEmbedder(nn.Module):
27
+ """
28
+ Embeds scalar timesteps into vector representations.
29
+ """
30
+ def __init__(self, hidden_size, frequency_embedding_size=256):
31
+ super().__init__()
32
+ self.mlp = nn.Sequential(
33
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
34
+ nn.SiLU(),
35
+ nn.Linear(hidden_size, hidden_size, bias=True),
36
+ )
37
+ self.frequency_embedding_size = frequency_embedding_size
38
+
39
+ @staticmethod
40
+ def timestep_embedding(t, dim, max_period=10000):
41
+ """
42
+ Create sinusoidal timestep embeddings.
43
+ :param t: a 1-D Tensor of N indices, one per batch element.
44
+ These may be fractional.
45
+ :param dim: the dimension of the output.
46
+ :param max_period: controls the minimum frequency of the embeddings.
47
+ :return: an (N, D) Tensor of positional embeddings.
48
+ """
49
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
50
+ half = dim // 2
51
+ freqs = torch.exp(
52
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
53
+ ).to(device=t.device)
54
+ args = t.float() * freqs[None]
55
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
56
+ if dim % 2:
57
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
58
+ return embedding
59
+
60
+ def forward(self, t):
61
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
62
+ t_emb = self.mlp(t_freq)
63
+ return t_emb
64
+
65
+ class ActionEmbedder(nn.Module):
66
+ """
67
+ Embeds action xy into vector representations.
68
+ """
69
+ def __init__(self, hidden_size, frequency_embedding_size=256):
70
+ super().__init__()
71
+ hsize = hidden_size//3
72
+ self.x_emb = TimestepEmbedder(hsize, frequency_embedding_size)
73
+ self.y_emb = TimestepEmbedder(hsize, frequency_embedding_size)
74
+ self.angle_emb = TimestepEmbedder(hidden_size -2*hsize, frequency_embedding_size)
75
+
76
+ def forward(self, xya):
77
+ return torch.cat([self.x_emb(xya[...,0:1]), self.y_emb(xya[...,1:2]), self.angle_emb(xya[...,2:3])], dim=-1)
78
+
79
+ #################################################################################
80
+ # Core AVCDiT Model #
81
+ #################################################################################
82
+
83
+ class AVCDiTBlock(nn.Module):
84
+ """
85
+ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning and two modalities.
86
+ """
87
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, mode="av", **block_kwargs):
88
+ super().__init__()
89
+ self.mode = mode
90
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
91
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
92
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
93
+ self.norm_cond = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
94
+ self.cttn = nn.MultiheadAttention(hidden_size, num_heads=num_heads, add_bias_kv=True, bias=True, batch_first=True, **block_kwargs)
95
+ self.adaLN_modulation = nn.Sequential(
96
+ nn.SiLU(),
97
+ nn.Linear(hidden_size, 11 * hidden_size, bias=True)
98
+ )
99
+
100
+ self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
101
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
102
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
103
+ if self.mode == "av" or self.mode == "v":
104
+ self.mlp_v = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
105
+ if self.mode == "av" or self.mode == "a":
106
+ self.mlp_a = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
107
+
108
+ # def forward(self, x_v, x_a, c, x_v_cond, x_a_cond, mode="av"):
109
+ def forward(self, *args):
110
+ if self.mode == "av":
111
+ x_v, x_a, c, x_v_cond, x_a_cond = args
112
+ shift_msa, scale_msa, gate_msa, shift_ca_xcond, scale_ca_xcond, shift_ca_x, scale_ca_x, gate_ca_x, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(11, dim=1)
113
+ _, v_token_num, _ = x_v.shape
114
+ x = torch.cat([x_v, x_a], dim=1)
115
+ x_cond = torch.cat([x_v_cond, x_a_cond], dim=1)
116
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
117
+ x_cond_norm = modulate(self.norm_cond(x_cond), shift_ca_xcond, scale_ca_xcond)
118
+ x = x + gate_ca_x.unsqueeze(1) * self.cttn(query=modulate(self.norm2(x), shift_ca_x, scale_ca_x), key=x_cond_norm, value=x_cond_norm, need_weights=False)[0]
119
+ x_v = x[:,:v_token_num,:]
120
+ x_a = x[:,v_token_num:,:]
121
+ x_v = x_v + gate_mlp.unsqueeze(1) * self.mlp_v(modulate(self.norm3(x_v), shift_mlp, scale_mlp))
122
+ x_a = x_a + gate_mlp.unsqueeze(1) * self.mlp_a(modulate(self.norm3(x_a), shift_mlp, scale_mlp))
123
+ return x_v, x_a
124
+ elif self.mode == "v":
125
+ x, c, x_cond = args
126
+ shift_msa, scale_msa, gate_msa, shift_ca_xcond, scale_ca_xcond, shift_ca_x, scale_ca_x, gate_ca_x, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(11, dim=1)
127
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
128
+ x_cond_norm = modulate(self.norm_cond(x_cond), shift_ca_xcond, scale_ca_xcond)
129
+ x = x + gate_ca_x.unsqueeze(1) * self.cttn(query=modulate(self.norm2(x), shift_ca_x, scale_ca_x), key=x_cond_norm, value=x_cond_norm, need_weights=False)[0]
130
+ x = x + gate_mlp.unsqueeze(1) * self.mlp_v(modulate(self.norm3(x), shift_mlp, scale_mlp))
131
+ return x
132
+ elif self.mode == "a":
133
+ x, c, x_cond = args
134
+ shift_msa, scale_msa, gate_msa, shift_ca_xcond, scale_ca_xcond, shift_ca_x, scale_ca_x, gate_ca_x, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(11, dim=1)
135
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
136
+ x_cond_norm = modulate(self.norm_cond(x_cond), shift_ca_xcond, scale_ca_xcond)
137
+ x = x + gate_ca_x.unsqueeze(1) * self.cttn(query=modulate(self.norm2(x), shift_ca_x, scale_ca_x), key=x_cond_norm, value=x_cond_norm, need_weights=False)[0]
138
+ x = x + gate_mlp.unsqueeze(1) * self.mlp_a(modulate(self.norm3(x), shift_mlp, scale_mlp))
139
+ return x
140
+
141
+
142
+ class FinalLayer(nn.Module):
143
+ """
144
+ The final layer of DiT.
145
+ """
146
+ def __init__(self, hidden_size, patch_size, out_channels):
147
+ super().__init__()
148
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
149
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
150
+ self.adaLN_modulation = nn.Sequential(
151
+ nn.SiLU(),
152
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
153
+ )
154
+
155
+ def forward(self, x, c):
156
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
157
+ x = modulate(self.norm_final(x), shift, scale)
158
+ x = self.linear(x)
159
+ return x
160
+
161
+
162
+ class FinalLayer_audio(nn.Module):
163
+ def __init__(self, hidden_size, out_channels):
164
+ super().__init__()
165
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
166
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True) # no patch²
167
+ self.adaLN_modulation = nn.Sequential(
168
+ nn.SiLU(),
169
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
170
+ )
171
+
172
+ def forward(self, x, c):
173
+ # x: (B, N, hidden_size), c: (B, hidden_size)
174
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) # shape (B, hidden_size)
175
+ x = modulate(self.norm_final(x), shift, scale) # apply AdaLN
176
+ x = self.linear(x) # → (B, N, out_channels)
177
+ return x
178
+
179
+
180
+ class AVCDiT(nn.Module):
181
+ """
182
+ Diffusion model with a Transformer backbone.
183
+ """
184
+ def __init__(
185
+ self,
186
+ input_size=32,
187
+ context_size=2,
188
+ patch_size=2,
189
+ in_channels=4,
190
+ hidden_size=1152,
191
+ depth=28,
192
+ num_heads=16,
193
+ mlp_ratio=4.0,
194
+ learn_sigma=True,
195
+ num_patches_a=180,
196
+ mode="av",
197
+ ):
198
+ super().__init__()
199
+ self.mode = mode
200
+ assert (self.mode=="av" or self.mode=="v" or self.mode=="a")
201
+ self.context_size = context_size
202
+ self.learn_sigma = learn_sigma
203
+ self.in_channels = in_channels
204
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
205
+ self.patch_size = patch_size
206
+ self.num_heads = num_heads
207
+
208
+
209
+ if self.mode == "av" or self.mode == "v":
210
+ self.x_embedder_v = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
211
+ num_patches_v = self.x_embedder_v.num_patches
212
+ self.pos_embed_v = nn.Parameter(torch.zeros(self.context_size + 1, num_patches_v, hidden_size), requires_grad=True) # for context and for predicted frame
213
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
214
+ if self.mode == "av" or self.mode == "a":
215
+ self.x_embedder_a = nn.Conv1d(
216
+ in_channels=16,
217
+ out_channels=hidden_size, # [B]
218
+ kernel_size=1,
219
+ stride=1,
220
+ bias=True
221
+ ) #TODO
222
+ self.pos_embed_a_cond = nn.Parameter(torch.zeros(self.context_size, num_patches_a, hidden_size), requires_grad=True)
223
+ self.pos_embed_a_pred = nn.Parameter(torch.zeros(1, num_patches_a+1, hidden_size), requires_grad=True)
224
+ self.final_layer_a = FinalLayer_audio(hidden_size=hidden_size, out_channels=32) # [B]
225
+
226
+ self.t_embedder = TimestepEmbedder(hidden_size)
227
+ self.y_embedder = ActionEmbedder(hidden_size)
228
+
229
+ # self.blocks = nn.ModuleList([AVCDiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])
230
+ self.blocks = nn.ModuleList([
231
+ AVCDiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, mode=self.mode)
232
+ for _ in range(depth)
233
+ ])
234
+
235
+ self.time_embedder = TimestepEmbedder(hidden_size)
236
+ self.initialize_weights()
237
+
238
+ def initialize_weights(self):
239
+ # Initialize transformer layers:
240
+ def _basic_init(module):
241
+ if isinstance(module, nn.Linear):
242
+ torch.nn.init.xavier_uniform_(module.weight)
243
+ if module.bias is not None:
244
+ nn.init.constant_(module.bias, 0)
245
+ self.apply(_basic_init)
246
+
247
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
248
+ if self.mode == "av" or self.mode == "v":
249
+ nn.init.normal_(self.pos_embed_v, std=0.02)
250
+ if self.mode == "av" or self.mode == "a":
251
+ nn.init.normal_(self.pos_embed_a_pred, std=0.02)
252
+ nn.init.normal_(self.pos_embed_a_cond, std=0.02)
253
+
254
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
255
+ if self.mode == "av" or self.mode == "v":
256
+ w = self.x_embedder_v.proj.weight.data
257
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
258
+ nn.init.constant_(self.x_embedder_v.proj.bias, 0)
259
+
260
+ # Initialize x_embedder_a (Conv1d) like linear
261
+ if self.mode == "av" or self.mode == "a":
262
+ w = self.x_embedder_a.weight.data
263
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
264
+ nn.init.constant_(self.x_embedder_a.bias, 0)
265
+
266
+
267
+ # Initialize action embedding:
268
+ nn.init.normal_(self.y_embedder.x_emb.mlp[0].weight, std=0.02)
269
+ nn.init.normal_(self.y_embedder.x_emb.mlp[2].weight, std=0.02)
270
+
271
+ nn.init.normal_(self.y_embedder.y_emb.mlp[0].weight, std=0.02)
272
+ nn.init.normal_(self.y_embedder.y_emb.mlp[2].weight, std=0.02)
273
+
274
+ nn.init.normal_(self.y_embedder.angle_emb.mlp[0].weight, std=0.02)
275
+ nn.init.normal_(self.y_embedder.angle_emb.mlp[2].weight, std=0.02)
276
+
277
+ # Initialize timestep embedding MLP:
278
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
279
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
280
+
281
+ nn.init.normal_(self.time_embedder.mlp[0].weight, std=0.02)
282
+ nn.init.normal_(self.time_embedder.mlp[2].weight, std=0.02)
283
+
284
+ # Zero-out adaLN modulation layers in DiT blocks:
285
+ for block in self.blocks:
286
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
287
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
288
+
289
+ # Zero-out output layers:
290
+ if self.mode == "av" or self.mode == "v":
291
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
292
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
293
+ nn.init.constant_(self.final_layer.linear.weight, 0)
294
+ nn.init.constant_(self.final_layer.linear.bias, 0)
295
+
296
+ if self.mode == "av" or self.mode == "a":
297
+ nn.init.constant_(self.final_layer_a.adaLN_modulation[-1].weight, 0)
298
+ nn.init.constant_(self.final_layer_a.adaLN_modulation[-1].bias, 0)
299
+ nn.init.constant_(self.final_layer_a.linear.weight, 0)
300
+ nn.init.constant_(self.final_layer_a.linear.bias, 0)
301
+
302
+ def unpatchify(self, x):
303
+ """
304
+ x: (N, T, patch_size**2 * C)
305
+ imgs: (N, H, W, C)
306
+ """
307
+ c = self.out_channels
308
+ p = self.x_embedder_v.patch_size[0]
309
+ h = w = int(x.shape[1] ** 0.5)
310
+ assert h * w == x.shape[1]
311
+
312
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
313
+ x = torch.einsum('nhwpqc->nchpwq', x)
314
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
315
+ return imgs
316
+
317
+ # def forward(self, x_v, x_a, t, y, x_v_cond, x_a_cond, rel_t):
318
+ # def forward(self, *args):
319
+ def forward(self, *args, **kwargs):
320
+ """
321
+ Forward pass of DiT.
322
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
323
+ t: (N,) tensor of diffusion timesteps
324
+ y: (N,) tensor of class labels
325
+ """
326
+ if self.mode == "av":
327
+ if len(args) >= 7:
328
+ x_v, x_a, t, y, x_v_cond, x_a_cond, rel_t = args[:7]
329
+ else:
330
+ assert len(args) == 3, f"mode='v' expects 2 or 5 positional args, got {len(args)}"
331
+ x_v, x_a, t = args
332
+ y = kwargs["y"]
333
+ x_v_cond = kwargs["x_v_cond"]
334
+ x_a_cond = kwargs["x_a_cond"]
335
+ rel_t = kwargs["rel_t"]
336
+
337
+ x_v = self.x_embedder_v(x_v) + self.pos_embed_v[self.context_size:]
338
+ x_v_cond = self.x_embedder_v(x_v_cond.flatten(0, 1)).unflatten(0, (x_v_cond.shape[0], x_v_cond.shape[1])) + self.pos_embed_v[:self.context_size] # (N, T, D), where T = H * W / patch_size ** 2.flatten(1, 2)
339
+ x_v_cond = x_v_cond.flatten(1, 2)
340
+
341
+ x_a = self.x_embedder_a(x_a) # → (B, embed_dim, L')
342
+ x_a = x_a.transpose(1, 2) # → (B, L', embed_dim)
343
+ x_a = x_a + self.pos_embed_a_pred
344
+
345
+ x_a_cond = self.x_embedder_a(x_a_cond.flatten(0, 1)).transpose(1, 2).unflatten(0, (x_a_cond.shape[0], x_a_cond.shape[1])) + self.pos_embed_a_cond
346
+ x_a_cond = x_a_cond.flatten(1, 2)
347
+
348
+ t = self.t_embedder(t[..., None])
349
+ y = self.y_embedder(y)
350
+ time_emb = self.time_embedder(rel_t[..., None])
351
+ c = t + time_emb + y # if training on unlabeled data, dont add y.
352
+
353
+ for block in self.blocks:
354
+ x_v, x_a = block(x_v, x_a, c, x_v_cond, x_a_cond)
355
+ x_v = self.final_layer(x_v, c)
356
+ x_v = self.unpatchify(x_v)
357
+ x_a = self.final_layer_a(x_a, c)
358
+ x_a = x_a.transpose(1, 2)
359
+ return x_v, x_a
360
+ elif self.mode == "v":
361
+ if len(args) >= 5:
362
+ x, t, y, x_cond, rel_t = args[:5]
363
+ else:
364
+ assert len(args) == 2, f"mode='v' expects 2 or 5 positional args, got {len(args)}"
365
+ x, t = args
366
+ y = kwargs["y"]
367
+ x_cond = kwargs["x_cond"]
368
+ rel_t = kwargs["rel_t"]
369
+ x = self.x_embedder_v(x) + self.pos_embed_v[self.context_size:]
370
+ x_cond = self.x_embedder_v(x_cond.flatten(0, 1)).unflatten(0, (x_cond.shape[0], x_cond.shape[1])) + self.pos_embed_v[:self.context_size] # (N, T, D), where T = H * W / patch_size ** 2.flatten(1, 2)
371
+ x_cond = x_cond.flatten(1, 2)
372
+ t = self.t_embedder(t[..., None])
373
+ y = self.y_embedder(y)
374
+ time_emb = self.time_embedder(rel_t[..., None])
375
+ c = t + time_emb + y # if training on unlabeled data, dont add y.
376
+ for block in self.blocks:
377
+ x = block(x, c, x_cond)
378
+ x = self.final_layer(x, c)
379
+ x = self.unpatchify(x)
380
+ return x
381
+ elif self.mode == "a":
382
+ if len(args) >= 5:
383
+ x, t, y, x_cond, rel_t = args[:5]
384
+ else:
385
+ assert len(args) == 2, f"mode='v' expects 2 or 5 positional args, got {len(args)}"
386
+ x, t = args
387
+ y = kwargs["y"]
388
+ x_cond = kwargs["x_cond"]
389
+ rel_t = kwargs["rel_t"]
390
+ x = self.x_embedder_a(x) # → (B, embed_dim, L')
391
+ x = x.transpose(1, 2) # → (B, L', embed_dim)
392
+ x = x + self.pos_embed_a_pred # [REWARD]
393
+ x_cond = self.x_embedder_a(x_cond.flatten(0, 1)).transpose(1, 2).unflatten(0, (x_cond.shape[0], x_cond.shape[1])) + self.pos_embed_a_cond # [REWARD]
394
+ x_cond = x_cond.flatten(1, 2)
395
+ t = self.t_embedder(t[..., None])
396
+ y = self.y_embedder(y)
397
+ time_emb = self.time_embedder(rel_t[..., None])
398
+ c = t + time_emb + y # if training on unlabeled data, dont add y.
399
+ for block in self.blocks:
400
+ x = block(x, c, x_cond)
401
+ x = self.final_layer_a(x, c)
402
+ x = x.transpose(1, 2)
403
+ return x
404
+
405
+ #################################################################################
406
+ # Sine/Cosine Positional Embedding Functions #
407
+ #################################################################################
408
+ # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
409
+
410
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
411
+ """
412
+ grid_size: int of the grid height and width
413
+ return:
414
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
415
+ """
416
+ grid_h = np.arange(grid_size, dtype=np.float32)
417
+ grid_w = np.arange(grid_size, dtype=np.float32)
418
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
419
+ grid = np.stack(grid, axis=0)
420
+
421
+ grid = grid.reshape([2, 1, grid_size, grid_size])
422
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
423
+ if cls_token and extra_tokens > 0:
424
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
425
+ return pos_embed
426
+
427
+
428
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
429
+ assert embed_dim % 2 == 0
430
+
431
+ # use half of dimensions to encode grid_h
432
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
433
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
434
+
435
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
436
+ return emb
437
+
438
+
439
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
440
+ """
441
+ embed_dim: output dimension for each position
442
+ pos: a list of positions to be encoded: size (M,)
443
+ out: (M, D)
444
+ """
445
+ assert embed_dim % 2 == 0
446
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
447
+ omega /= embed_dim / 2.
448
+ omega = 1. / 10000**omega # (D/2,)
449
+
450
+ pos = pos.reshape(-1) # (M,)
451
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
452
+
453
+ emb_sin = np.sin(out) # (M, D/2)
454
+ emb_cos = np.cos(out) # (M, D/2)
455
+
456
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
457
+ return emb
458
+
459
+
460
+ #################################################################################
461
+ # AVCDiT Configs #
462
+ #################################################################################
463
+
464
+ def AVCDiT_XL_2(**kwargs):
465
+ return AVCDiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
466
+
467
+ def AVCDiT_L_2(**kwargs):
468
+ return AVCDiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
469
+
470
+ def AVCDiT_B_2(**kwargs):
471
+ return AVCDiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
472
+
473
+ def AVCDiT_S_2(**kwargs):
474
+ return AVCDiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
475
+
476
+
477
+ AVCDiT_models = {
478
+ 'AVCDiT-XL/2': AVCDiT_XL_2,
479
+ 'AVCDiT-L/2': AVCDiT_L_2,
480
+ 'AVCDiT-B/2': AVCDiT_B_2,
481
+ 'AVCDiT-S/2': AVCDiT_S_2
482
+ }
soundstream.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn.utils import weight_norm
5
+
6
+ from vector_quantize_pytorch import ResidualVQ
7
+
8
+ class CausalConv1d(nn.Conv1d):
9
+ def __init__(self, *args, **kwargs):
10
+ super().__init__(*args, **kwargs)
11
+ self.causal_padding = self.dilation[0] * (self.kernel_size[0] - 1)
12
+
13
+ def forward(self, x):
14
+ return self._conv_forward(F.pad(x, [self.causal_padding, 0]), self.weight, self.bias)
15
+
16
+
17
+ class CausalConvTranspose1d(nn.ConvTranspose1d):
18
+ def __init__(self, *args, **kwargs):
19
+ super().__init__(*args, **kwargs)
20
+ self.causal_padding = self.dilation[0] * (self.kernel_size[0] - 1) + self.output_padding[0] + 1 - self.stride[0]
21
+
22
+ def forward(self, x, output_size=None):
23
+ if self.padding_mode != 'zeros':
24
+ raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')
25
+
26
+ assert isinstance(self.padding, tuple)
27
+ output_padding = self._output_padding(
28
+ x, output_size, self.stride, self.padding, self.kernel_size, self.dilation)
29
+ return F.conv_transpose1d(
30
+ x, self.weight, self.bias, self.stride, self.padding,
31
+ output_padding, self.groups, self.dilation)[...,:-self.causal_padding]
32
+
33
+
34
+ class ResidualUnit(nn.Module):
35
+ def __init__(self, in_channels, out_channels, dilation):
36
+ super().__init__()
37
+
38
+ self.dilation = dilation
39
+
40
+ self.layers = nn.Sequential(
41
+ CausalConv1d(in_channels=in_channels, out_channels=out_channels,
42
+ kernel_size=7, dilation=dilation),
43
+ nn.ELU(),
44
+ nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
45
+ kernel_size=1)
46
+ )
47
+
48
+ def forward(self, x):
49
+ return x + self.layers(x)
50
+
51
+
52
+ class EncoderBlock(nn.Module):
53
+ def __init__(self, out_channels, stride):
54
+ super().__init__()
55
+
56
+ self.layers = nn.Sequential(
57
+ ResidualUnit(in_channels=out_channels//2,
58
+ out_channels=out_channels//2, dilation=1),
59
+ nn.ELU(),
60
+ ResidualUnit(in_channels=out_channels//2,
61
+ out_channels=out_channels//2, dilation=3),
62
+ nn.ELU(),
63
+ ResidualUnit(in_channels=out_channels//2,
64
+ out_channels=out_channels//2, dilation=9),
65
+ nn.ELU(),
66
+ CausalConv1d(in_channels=out_channels//2, out_channels=out_channels,
67
+ kernel_size=2*stride, stride=stride)
68
+ )
69
+
70
+ def forward(self, x):
71
+ return self.layers(x)
72
+
73
+
74
+ class DecoderBlock(nn.Module):
75
+ def __init__(self, out_channels, stride):
76
+ super().__init__()
77
+
78
+ self.layers = nn.Sequential(
79
+ CausalConvTranspose1d(in_channels=2*out_channels,
80
+ out_channels=out_channels,
81
+ kernel_size=2*stride, stride=stride),
82
+ nn.ELU(),
83
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
84
+ dilation=1),
85
+ nn.ELU(),
86
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
87
+ dilation=3),
88
+ nn.ELU(),
89
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
90
+ dilation=9),
91
+
92
+ )
93
+
94
+ def forward(self, x):
95
+ return self.layers(x)
96
+
97
+
98
+ class Encoder(nn.Module):
99
+ def __init__(self, C, D):
100
+ super().__init__()
101
+
102
+ self.layers = nn.Sequential(
103
+ CausalConv1d(in_channels=2, out_channels=C, kernel_size=7),
104
+ nn.ELU(),
105
+ EncoderBlock(out_channels=2*C, stride=2),
106
+ nn.ELU(),
107
+ EncoderBlock(out_channels=4*C, stride=4),
108
+ nn.ELU(),
109
+ EncoderBlock(out_channels=8*C, stride=5),
110
+ nn.ELU(),
111
+ # EncoderBlock(out_channels=16*C, stride=8),
112
+ # nn.ELU(),
113
+ # CausalConv1d(in_channels=16*C, out_channels=D, kernel_size=3)
114
+ CausalConv1d(in_channels=8*C, out_channels=D, kernel_size=3)
115
+ )
116
+
117
+ def forward(self, x):
118
+ return self.layers(x)
119
+
120
+
121
+ class Decoder(nn.Module):
122
+ def __init__(self, C, D):
123
+ super().__init__()
124
+
125
+ self.layers = nn.Sequential(
126
+ CausalConv1d(in_channels=D, out_channels=8*C, kernel_size=7),
127
+ # CausalConv1d(in_channels=D, out_channels=16*C, kernel_size=7),
128
+ # nn.ELU(),
129
+ # DecoderBlock(out_channels=8*C, stride=8),
130
+ nn.ELU(),
131
+ DecoderBlock(out_channels=4*C, stride=5),
132
+ nn.ELU(),
133
+ DecoderBlock(out_channels=2*C, stride=4),
134
+ nn.ELU(),
135
+ DecoderBlock(out_channels=C, stride=2),
136
+ nn.ELU(),
137
+ CausalConv1d(in_channels=C, out_channels=2, kernel_size=7)
138
+ )
139
+
140
+ def forward(self, x):
141
+ return self.layers(x)
142
+
143
+
144
+ class SoundStream(nn.Module):
145
+ def __init__(self, C, D, n_q, codebook_size):
146
+ super().__init__()
147
+
148
+ self.encoder = Encoder(C=C, D=D)
149
+ self.quantizer = ResidualVQ(
150
+ num_quantizers=n_q, dim=D, codebook_size=codebook_size,
151
+ kmeans_init=True, kmeans_iters=100, threshold_ema_dead_code=2
152
+ )
153
+ self.decoder = Decoder(C=C, D=D)
154
+
155
+ @staticmethod
156
+ def pad_to_multiple(x, multiple):
157
+ """
158
+ x: [B, C, T]
159
+ multiple: int, e.g., 320
160
+ return: padded_x, original_length
161
+ """
162
+ B, C, T = x.shape
163
+ target_len = ((T + multiple - 1) // multiple) * multiple
164
+ pad_len = target_len - T
165
+ padded_x = F.pad(x, (0, pad_len), mode='reflect')
166
+ return padded_x, T
167
+
168
+ @staticmethod
169
+ def crop_to_length(x, original_length):
170
+ return x[..., :original_length]
171
+
172
+ def forward(self, x):
173
+ e = self.encoder(x) # [B, D, T']
174
+ e = e.permute(0, 2, 1) # → [B, T', D]
175
+ quantized, _, _ = self.quantizer(e)
176
+ quantized = quantized.permute(0, 2, 1) # → [B, D, T']
177
+ o = self.decoder(quantized) # → [B, 2, T_padded]
178
+ return o
train_avwm_stage1.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from inference_avwm import model_forward_wrapper_v
2
+ import torch
3
+ # the first flag below was False when we tested this script but True makes A100 training a lot faster:
4
+ torch.backends.cuda.matmul.allow_tf32 = True
5
+ torch.backends.cudnn.allow_tf32 = True
6
+
7
+ import matplotlib
8
+ matplotlib.use('Agg')
9
+ from collections import OrderedDict
10
+ from copy import deepcopy
11
+ from time import time
12
+ import argparse
13
+ import logging
14
+ import os
15
+ import matplotlib.pyplot as plt
16
+ import yaml
17
+
18
+
19
+ import torch.distributed as dist
20
+ from torch.nn.parallel import DistributedDataParallel as DDP
21
+ from torch.utils.data import DataLoader, ConcatDataset
22
+ from torch.utils.data.distributed import DistributedSampler
23
+ from diffusers.models import AutoencoderKL
24
+
25
+ from distributed import init_distributed
26
+ from models import AVCDiT_models
27
+ from diffusion import create_diffusion
28
+ from datasets import TrainingDataset
29
+ from misc import transform
30
+
31
+ #################################################################################
32
+ # Training Helper Functions #
33
+ #################################################################################
34
+
35
+
36
+ def load_checkpoint_if_available(model, ema, opt, scaler, config, device, logger, args):
37
+ start_epoch = 0
38
+ train_steps = 0
39
+ latest_path = os.path.join(config['results_dir'], config['run_name'], "checkpoints", "latest.pth.tar")
40
+ if os.path.isfile(latest_path) or config.get('from_checkpoint', 0):
41
+ latest_path = latest_path if os.path.isfile(latest_path) else config.get('from_checkpoint', 0)
42
+ print("Loading model from ", latest_path)
43
+ checkpoint = torch.load(latest_path, map_location=f"cuda:{device}", weights_only=False)
44
+
45
+ ema_ckp = {k.replace('_orig_mod.', ''): v for k, v in checkpoint["ema"].items()}
46
+ remapped = {}
47
+ for k, v in ema_ckp.items():
48
+ new_k = k
49
+ # 1) pos_embed -> pos_embed_v
50
+ if k.startswith("pos_embed"):
51
+ new_k = k.replace("pos_embed", "pos_embed_v", 1)
52
+ # 2) x_embedder. -> x_embedder_v.
53
+ if new_k.startswith("x_embedder."):
54
+ new_k = new_k.replace("x_embedder.", "x_embedder_v.", 1)
55
+ # 3) blocks.*.mlp.*: .mlp. -> .mlp_v.
56
+ if new_k.startswith("blocks.") and ".mlp." in new_k:
57
+ new_k = new_k.replace(".mlp.", ".mlp_v.", 1)
58
+ remapped[new_k] = v
59
+ ema_ckp = remapped
60
+ model.load_state_dict(ema_ckp, strict=True)
61
+ print("Model weights loaded.")
62
+ ema.load_state_dict(ema_ckp, strict=True)
63
+ print("EMA weights loaded.")
64
+
65
+ if args.restart_from_checkpoint:
66
+ logger.info("Restarting training: epoch and step counters set to 0.")
67
+ else:
68
+ if "opt" in checkpoint:
69
+ opt_ckp = {k.replace('_orig_mod.', ''): v for k, v in checkpoint["opt"].items()}
70
+ opt.load_state_dict(opt_ckp)
71
+ print("Optimizer state loaded.")
72
+ if "scaler" in checkpoint and scaler is not None:
73
+ scaler.load_state_dict(checkpoint["scaler"])
74
+ print("GradScaler state loaded.")
75
+ if "epoch" in checkpoint:
76
+ start_epoch = checkpoint["epoch"] + 1
77
+ if "train_steps" in checkpoint:
78
+ train_steps = checkpoint["train_steps"]
79
+ logger.info(f"Resuming from epoch {start_epoch}, step {train_steps}")
80
+
81
+ return start_epoch, train_steps
82
+
83
+
84
+ @torch.no_grad()
85
+ def update_ema(ema_model, model, decay=0.9999):
86
+ """
87
+ Step the EMA model towards the current model.
88
+ """
89
+ ema_params = OrderedDict(ema_model.named_parameters())
90
+ model_params = OrderedDict(model.named_parameters())
91
+
92
+ for name, param in model_params.items():
93
+ name = name.replace('_orig_mod.', '')
94
+ ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
95
+
96
+
97
+ def requires_grad(model, flag=True):
98
+ """
99
+ Set requires_grad flag for all parameters in a model.
100
+ """
101
+ for p in model.parameters():
102
+ p.requires_grad = flag
103
+
104
+
105
+ def cleanup():
106
+ """
107
+ End DDP training.
108
+ """
109
+ dist.destroy_process_group()
110
+
111
+
112
+ def create_logger(logging_dir):
113
+ """
114
+ Create a logger that writes to a log file and stdout.
115
+ """
116
+ if dist.get_rank() == 0: # real logger
117
+ logging.basicConfig(
118
+ level=logging.INFO,
119
+ format='[\033[34m%(asctime)s\033[0m] %(message)s',
120
+ datefmt='%Y-%m-%d %H:%M:%S',
121
+ handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
122
+ )
123
+ logger = logging.getLogger(__name__)
124
+ else: # dummy logger (does nothing)
125
+ logger = logging.getLogger(__name__)
126
+ logger.addHandler(logging.NullHandler())
127
+ return logger
128
+
129
+ #################################################################################
130
+ # Training Loop #
131
+ #################################################################################
132
+
133
+ def main(args):
134
+ """
135
+ Trains a new AVCDiT model.
136
+ """
137
+ assert torch.cuda.is_available(), "Training currently requires at least one GPU."
138
+
139
+ # Setup DDP:
140
+ _, rank, device, _ = init_distributed()
141
+ # rank = dist.get_rank()
142
+ seed = args.global_seed * dist.get_world_size() + rank
143
+ torch.manual_seed(seed)
144
+ print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
145
+ with open("config/eval_config.yaml", "r") as f:
146
+ default_config = yaml.safe_load(f)
147
+ config = default_config
148
+
149
+ with open(args.config, "r") as f:
150
+ user_config = yaml.safe_load(f)
151
+ config.update(user_config)
152
+
153
+ # Setup an experiment folder:
154
+ os.makedirs(config['results_dir'], exist_ok=True) # Make results folder (holds all experiment subfolders)
155
+ experiment_dir = f"{config['results_dir']}/{config['run_name']}" # Create an experiment folder
156
+ checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
157
+ if rank == 0:
158
+ os.makedirs(checkpoint_dir, exist_ok=True)
159
+ logger = create_logger(experiment_dir)
160
+ logger.info(f"Experiment directory created at {experiment_dir}")
161
+ else:
162
+ logger = create_logger(None)
163
+
164
+ # Create model:
165
+ tokenizer = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema").to(device)
166
+ latent_size = config['image_size'] // 8
167
+
168
+ assert config['image_size'] % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
169
+ num_cond = config['context_size']
170
+ model = AVCDiT_models[config['model']](context_size=num_cond, input_size=latent_size, in_channels=4, mode="v").to(device)
171
+
172
+ ema = deepcopy(model).to(device) # Create an EMA of the model for use after training
173
+ requires_grad(ema, False)
174
+
175
+ # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
176
+ lr = float(config.get('lr', 1e-4))
177
+ opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0)
178
+
179
+
180
+ bfloat_enable = bool(hasattr(args, 'bfloat16') and args.bfloat16)
181
+ if bfloat_enable:
182
+ scaler = torch.amp.GradScaler()
183
+
184
+ # load existing checkpoint
185
+ # latest_path = os.path.join(checkpoint_dir, "latest.pth.tar")
186
+ # === Load checkpoint or start from a pretrained one ===
187
+ start_epoch, train_steps = load_checkpoint_if_available(
188
+ model, ema, opt, scaler if bfloat_enable else None, config, device, logger, args
189
+ )
190
+
191
+ # ~40% speedup but might leads to worse performance depending on pytorch version
192
+ if args.torch_compile:
193
+ model = torch.compile(model)
194
+ model = DDP(model, device_ids=[device])
195
+ diffusion = create_diffusion(timestep_respacing="") # default: 1000 steps, linear noise schedule
196
+ # ,predict_xstart=True
197
+ logger.info(f"AVCDiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
198
+
199
+ train_dataset = []
200
+ test_dataset = []
201
+
202
+ for dataset_name in config["datasets"]:
203
+ data_config = config["datasets"][dataset_name]
204
+
205
+ for data_split_type in ["train", "test"]:
206
+ if data_split_type in data_config:
207
+ goals_per_obs = int(data_config["goals_per_obs"])
208
+ if data_split_type == 'test':
209
+ goals_per_obs = 4 # standardize testing
210
+
211
+ if "distance" in data_config:
212
+ min_dist_cat=data_config["distance"]["min_dist_cat"]
213
+ max_dist_cat=data_config["distance"]["max_dist_cat"]
214
+ else:
215
+ min_dist_cat=config["distance"]["min_dist_cat"]
216
+ max_dist_cat=config["distance"]["max_dist_cat"]
217
+
218
+ if "len_traj_pred" in data_config:
219
+ len_traj_pred=data_config["len_traj_pred"]
220
+ else:
221
+ len_traj_pred=config["len_traj_pred"]
222
+
223
+ dataset = TrainingDataset(
224
+ data_folder=data_config["data_folder"],
225
+ data_split_folder=data_config[data_split_type],
226
+ dataset_name=dataset_name,
227
+ image_size=config["image_size"],
228
+ min_dist_cat=min_dist_cat,
229
+ max_dist_cat=max_dist_cat,
230
+ len_traj_pred=len_traj_pred,
231
+ context_size=config["context_size"],
232
+ normalize=config["normalize"],
233
+ goals_per_obs=goals_per_obs,
234
+ transform=transform,
235
+ predefined_index=None,
236
+ traj_stride=1,
237
+ evaluate=(data_split_type=="test")
238
+ )
239
+ if data_split_type == "train":
240
+ train_dataset.append(dataset)
241
+ else:
242
+ test_dataset.append(dataset)
243
+ print(f"Dataset: {dataset_name} ({data_split_type}), size: {len(dataset)}")
244
+
245
+ # combine all the datasets from different robots
246
+ print(f"Combining {len(train_dataset)} datasets.")
247
+ train_dataset = ConcatDataset(train_dataset)
248
+ test_dataset = ConcatDataset(test_dataset)
249
+
250
+ sampler = DistributedSampler(
251
+ train_dataset,
252
+ num_replicas=dist.get_world_size(),
253
+ rank=rank,
254
+ shuffle=True,
255
+ seed=args.global_seed
256
+ )
257
+ loader = DataLoader(
258
+ train_dataset,
259
+ batch_size=config['batch_size'],
260
+ shuffle=False,
261
+ sampler=sampler,
262
+ num_workers=config['num_workers'],
263
+ pin_memory=True,
264
+ drop_last=True,
265
+ persistent_workers=True
266
+ )
267
+ logger.info(f"Dataset contains {len(train_dataset):,} images")
268
+
269
+ # Prepare models for training:
270
+ model.train() # important! This enables embedding dropout for classifier-free guidance
271
+ ema.eval() # EMA model should always be in eval mode
272
+
273
+ # Variables for monitoring/logging purposes:
274
+ log_steps = 0
275
+ running_loss = 0
276
+ start_time = time()
277
+
278
+ logger.info(f"Training for {args.epochs} epochs...")
279
+ for epoch in range(start_epoch, args.epochs):
280
+ sampler.set_epoch(epoch)
281
+ steps_per_epoch = len(loader)
282
+ if rank == 0:
283
+ logger.info(f"Epoch {epoch} contains {steps_per_epoch} steps.")
284
+ logger.info(f"Beginning epoch {epoch}...")
285
+
286
+ for x, _, y, _, rel_t in loader:
287
+ x = x.to(device, non_blocking=True)
288
+ y = y.to(device, non_blocking=True)
289
+ rel_t = rel_t.to(device, non_blocking=True)
290
+
291
+ with torch.amp.autocast('cuda', enabled=bfloat_enable, dtype=torch.bfloat16):
292
+ with torch.no_grad():
293
+ # Map input images to latent space + normalize latents:
294
+ B, T = x.shape[:2]
295
+ x = x.flatten(0,1)
296
+ x = tokenizer.encode(x).latent_dist.sample().mul_(0.18215)
297
+ x = x.unflatten(0, (B, T))
298
+
299
+ num_goals = T - num_cond
300
+ x_start = x[:, num_cond:].flatten(0, 1)
301
+ x_cond = x[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x.shape[2], x.shape[3], x.shape[4]).flatten(0, 1)
302
+ y = y.flatten(0, 1)
303
+ rel_t = rel_t.flatten(0, 1)
304
+
305
+ t = torch.randint(0, diffusion.num_timesteps, (x_start.shape[0],), device=device)
306
+ model_kwargs = dict(y=y, x_cond=x_cond, rel_t=rel_t)
307
+ loss_dict = diffusion.training_losses(model, x_start, t, model_kwargs)
308
+ loss = loss_dict["loss"].mean()
309
+
310
+ if not bfloat_enable:
311
+ opt.zero_grad()
312
+ loss.backward()
313
+ opt.step()
314
+ else:
315
+ scaler.scale(loss).backward()
316
+ if config.get('grad_clip_val', 0) > 0:
317
+ scaler.unscale_(opt)
318
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config['grad_clip_val'])
319
+ scaler.step(opt)
320
+ scaler.update()
321
+
322
+ update_ema(ema, model.module)
323
+
324
+ # Log loss values:
325
+ running_loss += loss.detach().item()
326
+ log_steps += 1
327
+ train_steps += 1
328
+ if train_steps % args.log_every == 0:
329
+ # Measure training speed:
330
+ torch.cuda.synchronize()
331
+ end_time = time()
332
+ steps_per_sec = log_steps / (end_time - start_time)
333
+ samples_per_sec = dist.get_world_size()*x_cond.shape[0]*steps_per_sec
334
+ # Reduce loss history over all processes:
335
+ avg_loss = torch.tensor(running_loss / log_steps, device=device)
336
+ dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
337
+ avg_loss = avg_loss.item() / dist.get_world_size()
338
+ total_steps = len(loader) * args.epochs
339
+ progress_pct = train_steps / total_steps * 100
340
+
341
+ remaining_steps = total_steps - train_steps
342
+ eta_seconds = remaining_steps / steps_per_sec if steps_per_sec > 0 else 0
343
+ eta_hours = eta_seconds / 3600
344
+
345
+ logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}, Samples/Sec: {samples_per_sec:.2f}")
346
+ logger.info(f"Progress: {progress_pct:.2f}% | ETA: {eta_hours:.1f}h")
347
+ # Reset monitoring variables:
348
+ running_loss = 0
349
+ log_steps = 0
350
+ start_time = time()
351
+
352
+ # Save DiT checkpoint:
353
+ if train_steps % args.ckpt_every == 0 and train_steps > 0:
354
+ if rank == 0:
355
+ checkpoint = {
356
+ "model": model.module.state_dict(),
357
+ "ema": ema.state_dict(),
358
+ "opt": opt.state_dict(),
359
+ "args": args,
360
+ "epoch": epoch,
361
+ "train_steps": train_steps
362
+ }
363
+ if bfloat_enable:
364
+ checkpoint.update({"scaler": scaler.state_dict()})
365
+ checkpoint_path = f"{checkpoint_dir}/latest.pth.tar"
366
+ torch.save(checkpoint, checkpoint_path)
367
+ if train_steps % (10*args.ckpt_every) == 0 and train_steps > 0:
368
+ checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pth.tar"
369
+ torch.save(checkpoint, checkpoint_path)
370
+ logger.info(f"Saved checkpoint to {checkpoint_path}")
371
+
372
+ if train_steps % args.eval_every == 0 and train_steps > 0:
373
+ eval_start_time = time()
374
+ # validation / test set evaluation
375
+ save_dir = os.path.join(experiment_dir, str(train_steps))
376
+ sim_score_val = evaluate(ema, tokenizer, diffusion, test_dataset, rank, config["batch_size"], config["num_workers"], latent_size, device, save_dir, args.global_seed, bfloat_enable, num_cond)
377
+ dist.barrier()
378
+ eval_end_time = time()
379
+ eval_time = eval_end_time - eval_start_time
380
+ # logger.info(f"(step={train_steps:07d}) Val Perceptual Loss: {sim_score_val:.4f}, Train Perceptual Loss: {sim_score_train:.4f}, Eval Time: {eval_time:.2f}")
381
+ logger.info(f"(step={train_steps:07d}) Val Perceptual Loss: {sim_score_val:.4f}, Eval Time: {eval_time:.2f}")
382
+
383
+ model.eval()
384
+ logger.info("Done!")
385
+ cleanup()
386
+
387
+
388
+ @torch.no_grad
389
+ def evaluate(model, vae, diffusion, test_dataloaders, rank, batch_size, num_workers, latent_size, device, save_dir, seed, bfloat_enable, num_cond):
390
+ sampler = DistributedSampler(
391
+ test_dataloaders,
392
+ num_replicas=dist.get_world_size(),
393
+ rank=rank,
394
+ shuffle=True,
395
+ seed=seed
396
+ )
397
+ loader = DataLoader(
398
+ test_dataloaders,
399
+ batch_size=batch_size,
400
+ shuffle=False,
401
+ sampler=sampler,
402
+ num_workers=num_workers,
403
+ pin_memory=True,
404
+ drop_last=True
405
+ )
406
+ from dreamsim import dreamsim
407
+ eval_model, _ = dreamsim(pretrained=True)
408
+ score = torch.tensor(0.).to(device)
409
+ n_samples = torch.tensor(0).to(device)
410
+
411
+ # Run for 1 step
412
+ for x, _, y, _, rel_t, _ in loader:
413
+ x = x.to(device)
414
+ y = y.to(device)
415
+ rel_t = rel_t.to(device).flatten(0, 1)
416
+ with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
417
+ B, T = x.shape[:2]
418
+ num_goals = T - num_cond
419
+ samples = model_forward_wrapper_v((model, diffusion, vae), x, y, num_timesteps=None, latent_size=latent_size, device=device, num_cond=num_cond, num_goals=num_goals, rel_t=rel_t)
420
+ x_start_pixels = x[:, num_cond:].flatten(0, 1)
421
+ x_cond_pixels = x[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x.shape[2], x.shape[3], x.shape[4]).flatten(0, 1)
422
+ samples = samples * 0.5 + 0.5
423
+ x_start_pixels = x_start_pixels * 0.5 + 0.5
424
+ x_cond_pixels = x_cond_pixels * 0.5 + 0.5
425
+ res = eval_model(x_start_pixels, samples)
426
+ score += res.sum()
427
+ n_samples += len(res)
428
+ break
429
+
430
+ if rank == 0:
431
+ os.makedirs(save_dir, exist_ok=True)
432
+ for i in range(min(samples.shape[0], 10)):
433
+ _, ax = plt.subplots(1,3,dpi=256)
434
+ ax[0].imshow((x_cond_pixels[i, -1].permute(1,2,0).cpu().numpy()*255).astype('uint8'))
435
+ ax[1].imshow((x_start_pixels[i].permute(1,2,0).cpu().numpy()*255).astype('uint8'))
436
+ ax[2].imshow((samples[i].permute(1,2,0).cpu().float().numpy()*255).astype('uint8'))
437
+ plt.savefig(f'{save_dir}/{i}.png')
438
+ plt.close()
439
+
440
+ dist.all_reduce(score)
441
+ dist.all_reduce(n_samples)
442
+ sim_score = score/n_samples
443
+ return sim_score
444
+
445
+
446
+ def get_args_parser():
447
+ parser = argparse.ArgumentParser()
448
+ parser.add_argument("--config", type=str, required=True)
449
+ parser.add_argument("--epochs", type=int, default=300)
450
+ # parser.add_argument("--global-batch-size", type=int, default=256)
451
+ parser.add_argument("--global-seed", type=int, default=0)
452
+ parser.add_argument("--log-every", type=int, default=100)
453
+ parser.add_argument("--ckpt-every", type=int, default=2000)
454
+ parser.add_argument("--eval-every", type=int, default=5000)
455
+ parser.add_argument("--bfloat16", type=int, default=1)
456
+ parser.add_argument("--torch-compile", type=int, default=1)
457
+ parser.add_argument("--restart-from-checkpoint", type=int, default=0,
458
+ help="If 1, only load model weights and reset epoch/step to zero (cold start)")
459
+ return parser
460
+
461
+ if __name__ == "__main__":
462
+ args = get_args_parser().parse_args()
463
+ main(args)
train_avwm_stage2.py ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # NoMaD, GNM, ViNT: https://github.com/robodhruv/visualnav-transformer
9
+ # --------------------------------------------------------
10
+
11
+ from inference_avwm import model_forward_wrapper_a
12
+ import torch
13
+ # the first flag below was False when we tested this script but True makes A100 training a lot faster:
14
+ torch.backends.cuda.matmul.allow_tf32 = True
15
+ torch.backends.cudnn.allow_tf32 = True
16
+
17
+ import matplotlib
18
+ matplotlib.use('Agg')
19
+ from collections import OrderedDict
20
+ from copy import deepcopy
21
+ from time import time
22
+ import argparse
23
+ import logging
24
+ import os
25
+ import matplotlib.pyplot as plt
26
+ import yaml
27
+
28
+
29
+ import torch.distributed as dist
30
+ from torch.nn.parallel import DistributedDataParallel as DDP
31
+ from torch.utils.data import DataLoader, ConcatDataset
32
+ from torch.utils.data.distributed import DistributedSampler
33
+ from diffusers.models import AutoencoderKL
34
+
35
+ from distributed import init_distributed
36
+ from models import AVCDiT_models
37
+ from diffusion import create_diffusion
38
+ from datasets import TrainingDataset
39
+ from misc import transform
40
+ from soundstream import SoundStream
41
+ # from audiovae import BinauralSeqTokenCodec
42
+ import torchaudio
43
+ from eval_audio import build_mel_transform, mel_cosine_stereo, drms_avg_db_stereo, save_ref_hat_spectrogram_panel
44
+
45
+
46
+ def load_checkpoint_if_available(model, ema, opt, scaler, config, device, logger, args):
47
+ start_epoch = 0
48
+ train_steps = 0
49
+ latest_path = os.path.join(config['results_dir'], config['run_name'], "checkpoints", "latest.pth.tar")
50
+ if os.path.isfile(latest_path) or config.get('from_checkpoint', 0):
51
+ latest_path = latest_path if os.path.isfile(latest_path) else config.get('from_checkpoint', 0)
52
+ print("Loading model from ", latest_path)
53
+ checkpoint = torch.load(latest_path, map_location=f"cuda:{device}", weights_only=False)
54
+ ema_ckp = {k.replace('_orig_mod.', ''): v for k, v in checkpoint["ema"].items()}
55
+ remapped = {}
56
+ for k, v in ema_ckp.items():
57
+ new_k = k
58
+ if new_k.startswith("blocks.") and ".mlp_v." in new_k:
59
+ new_k = new_k.replace(".mlp_v.", ".mlp_a.", 1)
60
+ remapped[new_k] = v
61
+ ema_ckp = remapped
62
+ model_state = model.state_dict()
63
+ load_info = model.load_state_dict(ema_ckp, strict=False)
64
+
65
+ print("Model weights loaded.")
66
+ ema.load_state_dict(ema_ckp, strict=False)
67
+ print("EMA weights loaded.")
68
+ if args.restart_from_checkpoint:
69
+ logger.info("Restarting training: epoch and step counters set to 0.")
70
+ else:
71
+ try:
72
+ if "opt" in checkpoint:
73
+ opt_ckp = {k.replace('_orig_mod.', ''): v for k, v in checkpoint["opt"].items()}
74
+ opt.load_state_dict(opt_ckp)
75
+ print("Optimizer state loaded.")
76
+ if "scaler" in checkpoint and scaler is not None:
77
+ scaler.load_state_dict(checkpoint["scaler"])
78
+ print("GradScaler state loaded.")
79
+ except ValueError as e:
80
+ print(f"[WARN] Skip loading opt and scaler")
81
+ if "epoch" in checkpoint:
82
+ start_epoch = checkpoint["epoch"] + 1
83
+ if "train_steps" in checkpoint:
84
+ train_steps = checkpoint["train_steps"]
85
+ logger.info(f"Resuming from epoch {start_epoch}, step {train_steps}")
86
+
87
+ return start_epoch, train_steps
88
+
89
+
90
+ @torch.no_grad()
91
+ def update_ema(ema_model, model, decay=0.9999):
92
+ """
93
+ Step the EMA model towards the current model.
94
+ """
95
+ ema_params = OrderedDict(ema_model.named_parameters())
96
+ model_params = OrderedDict(model.named_parameters())
97
+
98
+ for name, param in model_params.items():
99
+ name = name.replace('_orig_mod.', '')
100
+ ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
101
+
102
+
103
+ def requires_grad(model, flag=True):
104
+ """
105
+ Set requires_grad flag for all parameters in a model.
106
+ """
107
+ for p in model.parameters():
108
+ p.requires_grad = flag
109
+
110
+
111
+ def cleanup():
112
+ """
113
+ End DDP training.
114
+ """
115
+ dist.destroy_process_group()
116
+
117
+
118
+ def create_logger(logging_dir):
119
+ """
120
+ Create a logger that writes to a log file and stdout.
121
+ """
122
+ if dist.get_rank() == 0: # real logger
123
+ logging.basicConfig(
124
+ level=logging.INFO,
125
+ format='[\033[34m%(asctime)s\033[0m] %(message)s',
126
+ datefmt='%Y-%m-%d %H:%M:%S',
127
+ handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
128
+ )
129
+ logger = logging.getLogger(__name__)
130
+ else: # dummy logger (does nothing)
131
+ logger = logging.getLogger(__name__)
132
+ logger.addHandler(logging.NullHandler())
133
+ return logger
134
+
135
+ #################################################################################
136
+ # Training Loop #
137
+ #################################################################################
138
+
139
+ def main(args):
140
+ """
141
+ Trains a new AVCDiT model.
142
+ """
143
+ assert torch.cuda.is_available(), "Training currently requires at least one GPU."
144
+
145
+ # Setup DDP:
146
+ _, rank, device, _ = init_distributed()
147
+ # rank = dist.get_rank()
148
+ seed = args.global_seed * dist.get_world_size() + rank
149
+ torch.manual_seed(seed)
150
+ print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
151
+ with open("config/eval_config.yaml", "r") as f:
152
+ default_config = yaml.safe_load(f)
153
+ config = default_config
154
+
155
+ with open(args.config, "r") as f:
156
+ user_config = yaml.safe_load(f)
157
+ config.update(user_config)
158
+
159
+ # Setup an experiment folder:
160
+ os.makedirs(config['results_dir'], exist_ok=True) # Make results folder (holds all experiment subfolders)
161
+ experiment_dir = f"{config['results_dir']}/{config['run_name']}" # Create an experiment folder
162
+ checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
163
+ if rank == 0:
164
+ os.makedirs(checkpoint_dir, exist_ok=True)
165
+ logger = create_logger(experiment_dir)
166
+ logger.info(f"Experiment directory created at {experiment_dir}")
167
+ else:
168
+ logger = create_logger(None)
169
+
170
+ # Create model:
171
+ tokenizer = SoundStream(C=32, D=16, n_q=8, codebook_size=1024).to(device)
172
+ tokenizer_path=config["tokenizer_a_path"]
173
+ checkpoint = torch.load(tokenizer_path, map_location=f"cuda:{device}")
174
+ tokenizer.load_state_dict(checkpoint["model_state"])
175
+ tokenizer.eval()
176
+
177
+ latent_size = config['image_size'] // 8
178
+
179
+ assert config['image_size'] % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
180
+ num_cond = config['context_size']
181
+ model = AVCDiT_models[config['model']](context_size=num_cond, input_size=latent_size, in_channels=4, mode="a").to(device)
182
+
183
+ ema = deepcopy(model).to(device)
184
+ requires_grad(ema, False)
185
+
186
+ lr = float(config.get('lr', 1e-4))
187
+ for param in model.parameters():
188
+ param.requires_grad = False
189
+ for param in model.x_embedder_a.parameters():
190
+ param.requires_grad = True
191
+ model.pos_embed_a_cond.requires_grad = True
192
+ model.pos_embed_a_pred.requires_grad = True
193
+ for param in model.final_layer_a.parameters():
194
+ param.requires_grad = True
195
+ for i, block in enumerate(model.blocks):
196
+ for name, param in block.named_parameters():
197
+ if name.startswith("mlp."):
198
+ param.requires_grad = True
199
+
200
+ opt = torch.optim.AdamW(
201
+ filter(lambda p: p.requires_grad, model.parameters()),
202
+ lr=lr, weight_decay=0
203
+ )
204
+
205
+ bfloat_enable = bool(hasattr(args, 'bfloat16') and args.bfloat16)
206
+ if bfloat_enable:
207
+ scaler = torch.amp.GradScaler()
208
+
209
+ start_epoch, train_steps = load_checkpoint_if_available(
210
+ model, ema, opt, scaler if bfloat_enable else None, config, device, logger, args
211
+ )
212
+
213
+ print("Trainable Parameters: ")
214
+ for name, param in model.named_parameters():
215
+ if param.requires_grad:
216
+ print(f" - {name}: {tuple(param.shape)}")
217
+ # =======================================================================================#
218
+
219
+ # ~40% speedup but might leads to worse performance depending on pytorch version
220
+ if args.torch_compile:
221
+ model = torch.compile(model)
222
+ model = DDP(model, device_ids=[device])
223
+ diffusion = create_diffusion(timestep_respacing="") # default: 1000 steps, linear noise schedule
224
+ # ,predict_xstart=True
225
+ logger.info(f"AVCDiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
226
+
227
+ train_dataset = []
228
+ test_dataset = []
229
+
230
+ for dataset_name in config["datasets"]:
231
+ data_config = config["datasets"][dataset_name]
232
+
233
+ for data_split_type in ["train", "test"]:
234
+ if data_split_type in data_config:
235
+ goals_per_obs = int(data_config["goals_per_obs"])
236
+ if data_split_type == 'test':
237
+ goals_per_obs = 4 # standardize testing
238
+
239
+ if "distance" in data_config:
240
+ min_dist_cat=data_config["distance"]["min_dist_cat"]
241
+ max_dist_cat=data_config["distance"]["max_dist_cat"]
242
+ else:
243
+ min_dist_cat=config["distance"]["min_dist_cat"]
244
+ max_dist_cat=config["distance"]["max_dist_cat"]
245
+
246
+ if "len_traj_pred" in data_config:
247
+ len_traj_pred=data_config["len_traj_pred"]
248
+ else:
249
+ len_traj_pred=config["len_traj_pred"]
250
+
251
+ dataset = TrainingDataset(
252
+ data_folder=data_config["data_folder"],
253
+ data_split_folder=data_config[data_split_type],
254
+ dataset_name=dataset_name,
255
+ image_size=config["image_size"],
256
+ min_dist_cat=min_dist_cat,
257
+ max_dist_cat=max_dist_cat,
258
+ len_traj_pred=len_traj_pred,
259
+ context_size=config["context_size"],
260
+ normalize=config["normalize"],
261
+ goals_per_obs=goals_per_obs,
262
+ transform=transform,
263
+ predefined_index=None,
264
+ traj_stride=1,
265
+ sample_rate=config["sample_rate"],
266
+ input_sr=config["input_sr"],
267
+ evaluate=(data_split_type=="test")
268
+ )
269
+ if data_split_type == "train":
270
+ train_dataset.append(dataset)
271
+ else:
272
+ test_dataset.append(dataset)
273
+ print(f"Dataset: {dataset_name} ({data_split_type}), size: {len(dataset)}")
274
+
275
+ # combine all the datasets from different robots
276
+ print(f"Combining {len(train_dataset)} datasets.")
277
+ train_dataset = ConcatDataset(train_dataset)
278
+ test_dataset = ConcatDataset(test_dataset)
279
+
280
+ sampler = DistributedSampler(
281
+ train_dataset,
282
+ num_replicas=dist.get_world_size(),
283
+ rank=rank,
284
+ shuffle=True,
285
+ seed=args.global_seed
286
+ )
287
+ loader = DataLoader(
288
+ train_dataset,
289
+ batch_size=config['batch_size'],
290
+ shuffle=False,
291
+ sampler=sampler,
292
+ num_workers=config['num_workers'],
293
+ pin_memory=True,
294
+ drop_last=True,
295
+ persistent_workers=True
296
+ )
297
+ logger.info(f"Dataset contains {len(train_dataset):,} images")
298
+
299
+ # Prepare models for training:
300
+ model.train() # important! This enables embedding dropout for classifier-free guidance
301
+ ema.eval() # EMA model should always be in eval mode
302
+
303
+ # Variables for monitoring/logging purposes:
304
+ log_steps = 0
305
+ running_loss = 0
306
+ start_time = time()
307
+
308
+ logger.info(f"Training for {args.epochs} epochs...")
309
+ for epoch in range(start_epoch, args.epochs):
310
+ sampler.set_epoch(epoch)
311
+ steps_per_epoch = len(loader)
312
+ if rank == 0:
313
+ logger.info(f"Epoch {epoch} contains {steps_per_epoch} steps.")
314
+ logger.info(f"Beginning epoch {epoch}...")
315
+
316
+ for _, x, y, diff, rel_t in loader:
317
+ x = x.to(device, non_blocking=True)
318
+ y = y.to(device, non_blocking=True)
319
+ diff = diff.to(device, non_blocking=True) # [REWARD]
320
+ rel_t = rel_t.to(device, non_blocking=True)
321
+
322
+ with torch.amp.autocast('cuda', enabled=bfloat_enable, dtype=torch.bfloat16):
323
+ with torch.no_grad():
324
+ # Map input images to latent space + normalize latents:
325
+ B, T = x.shape[:2]
326
+ x = x.flatten(0,1)
327
+ x = tokenizer.encoder(x)
328
+ x = x.unflatten(0, (B, T))
329
+
330
+ num_goals = T - num_cond
331
+ x_start = x[:, num_cond:].flatten(0, 1)
332
+ x_cond = x[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x.shape[2], x.shape[3]).flatten(0, 1)
333
+ y = y.flatten(0, 1)
334
+ rel_t = rel_t.flatten(0, 1)
335
+
336
+ diff = diff.flatten(0, 1)
337
+ diff_tok = diff.unsqueeze(1).expand(-1, 16, -1)
338
+ x_start = torch.cat([x_start, diff_tok], dim=2)
339
+
340
+ t = torch.randint(0, diffusion.num_timesteps, (x_start.shape[0],), device=device)
341
+
342
+ model_kwargs = dict(y=y, x_cond=x_cond, rel_t=rel_t)
343
+ loss_dict = diffusion.training_losses(model, x_start, t, model_kwargs)
344
+ loss = loss_dict["loss"].mean()
345
+
346
+ if not bfloat_enable:
347
+ opt.zero_grad()
348
+ loss.backward()
349
+ opt.step()
350
+ else:
351
+ scaler.scale(loss).backward()
352
+ if config.get('grad_clip_val', 0) > 0:
353
+ scaler.unscale_(opt)
354
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config['grad_clip_val'])
355
+ scaler.step(opt)
356
+ scaler.update()
357
+
358
+ update_ema(ema, model.module)
359
+
360
+ # Log loss values:
361
+ running_loss += loss.detach().item()
362
+ log_steps += 1
363
+ train_steps += 1
364
+ if train_steps % args.log_every == 0:
365
+ # Measure training speed:
366
+ torch.cuda.synchronize()
367
+ end_time = time()
368
+ steps_per_sec = log_steps / (end_time - start_time)
369
+ samples_per_sec = dist.get_world_size()*x_cond.shape[0]*steps_per_sec
370
+ # Reduce loss history over all processes:
371
+ avg_loss = torch.tensor(running_loss / log_steps, device=device)
372
+ dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
373
+ avg_loss = avg_loss.item() / dist.get_world_size()
374
+ total_steps = len(loader) * args.epochs
375
+ progress_pct = train_steps / total_steps * 100
376
+
377
+ remaining_steps = total_steps - train_steps
378
+ eta_seconds = remaining_steps / steps_per_sec if steps_per_sec > 0 else 0
379
+ eta_hours = eta_seconds / 3600
380
+
381
+ logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}, Samples/Sec: {samples_per_sec:.2f}")
382
+ logger.info(f"Progress: {progress_pct:.2f}% | ETA: {eta_hours:.1f}h")
383
+ running_loss = 0
384
+ log_steps = 0
385
+ start_time = time()
386
+
387
+ # Save DiT checkpoint:
388
+ if train_steps % args.ckpt_every == 0 and train_steps > 0:
389
+ if rank == 0:
390
+ checkpoint = {
391
+ "model": model.module.state_dict(),
392
+ "ema": ema.state_dict(),
393
+ "opt": opt.state_dict(),
394
+ "args": args,
395
+ "epoch": epoch,
396
+ "train_steps": train_steps
397
+ }
398
+ if bfloat_enable:
399
+ checkpoint.update({"scaler": scaler.state_dict()})
400
+ checkpoint_path = f"{checkpoint_dir}/latest.pth.tar"
401
+ torch.save(checkpoint, checkpoint_path)
402
+ if train_steps % (10*args.ckpt_every) == 0 and train_steps > 0:
403
+ checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pth.tar"
404
+ torch.save(checkpoint, checkpoint_path)
405
+ logger.info(f"Saved checkpoint to {checkpoint_path}")
406
+
407
+ if train_steps % args.eval_every == 0 and train_steps > 0:
408
+ eval_start_time = time()
409
+ save_dir = os.path.join(experiment_dir, str(train_steps))
410
+ save_dir_train = os.path.join(experiment_dir, f"{train_steps}_train")
411
+ evaluate(ema, tokenizer, diffusion, test_dataset, rank, config["batch_size"], config["num_workers"], latent_size, device, save_dir_train, args.global_seed, bfloat_enable, num_cond, config["sample_rate"], config["input_sr"], logger)
412
+ dist.barrier()
413
+ eval_end_time = time()
414
+ eval_time = eval_end_time - eval_start_time
415
+
416
+ model.eval() # important! This disables randomized embedding dropout
417
+ # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
418
+
419
+ logger.info("Done!")
420
+ cleanup()
421
+
422
+
423
+ def denormalize_dis(ndata: float, min_v=-20.0, max_v=20.0, scale=0.15):
424
+ n01 = (float(ndata) + 1.0) / 2.0
425
+ raw = n01 * (max_v - min_v) + min_v
426
+ return raw * scale
427
+
428
+ @torch.no_grad()
429
+ def evaluate(model, vae, diffusion, test_dataloaders, rank, batch_size, num_workers, latent_size, device, save_dir, seed, bfloat_enable, num_cond, sample_rate, input_sr, logger):
430
+ sampler = DistributedSampler(
431
+ test_dataloaders,
432
+ num_replicas=dist.get_world_size(),
433
+ rank=rank,
434
+ shuffle=True,
435
+ seed=seed
436
+ )
437
+ loader = DataLoader(
438
+ test_dataloaders,
439
+ batch_size=batch_size,
440
+ shuffle=False,
441
+ sampler=sampler,
442
+ num_workers=num_workers,
443
+ pin_memory=True,
444
+ drop_last=True
445
+ )
446
+
447
+ down_resampler = torchaudio.transforms.Resample(orig_freq=input_sr, new_freq=sample_rate, lowpass_filter_width=64).to(device, dtype=torch.bfloat16) # [RESAMPLE]
448
+ mel_tf = build_mel_transform(
449
+ sample_rate=sample_rate,
450
+ n_fft=1024, win_length=1024, hop_length=256,
451
+ n_mels=80, power=1.0,
452
+ device=device,
453
+ )
454
+ # Run for 1 step
455
+ for _, x, y, diff, rel_t, x_orig in loader:
456
+ x = x.to(device)
457
+ y = y.to(device)
458
+ diff = diff.to(device).flatten(0, 1) # [REWARD]
459
+ rel_t = rel_t.to(device).flatten(0, 1)
460
+ x_orig = x_orig.to(device)
461
+ with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
462
+ B, T = x.shape[:2]
463
+ num_goals = T - num_cond
464
+ samples, diff_pred = model_forward_wrapper_a((model, diffusion, vae), x, y, num_timesteps=None, latent_size=latent_size, device=device, num_cond=num_cond, num_goals=num_goals, rel_t=rel_t)
465
+
466
+ decoded = down_resampler(samples)
467
+
468
+ x_start_pixels = x_orig[:, num_cond:].flatten(0, 1)
469
+ x_cond_pixels = x_orig[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x_orig.shape[2], x_orig.shape[3]).flatten(0, 1)
470
+ break
471
+
472
+ if rank == 0:
473
+ os.makedirs(save_dir, exist_ok=True)
474
+
475
+ num_save = min(samples.shape[0], 10)
476
+
477
+ if diff is not None: # [REWARD]
478
+ mae = torch.mean(torch.abs(diff_pred - diff))
479
+ logger.info(f"Distance Diff MAE = {mae.item():.6f}")
480
+ mel_cosine_ls=[]
481
+ for i in range(num_save):
482
+ mel_cos = mel_cosine_stereo(x_start_pixels[i], decoded[i], sample_rate=sample_rate, mel_tf=mel_tf)
483
+ mel_cosine_ls.append(mel_cos)
484
+ ok = save_ref_hat_spectrogram_panel(
485
+ x_start_pixels[i], decoded[i],
486
+ out_path=f"{save_dir}/{i}_spectrograms.png",
487
+ n_fft=512, hop_length=160, win_length=400, pool=4,
488
+ title="gt vs pred"
489
+ )
490
+
491
+ torchaudio.save(f"{save_dir}/{i}_gen.wav", decoded[i].cpu().to(torch.float32), sample_rate=sample_rate)
492
+ torchaudio.save(f"{save_dir}/{i}_gt.wav", x_start_pixels[i].cpu().to(torch.float32), sample_rate=sample_rate)
493
+ torchaudio.save(f"{save_dir}/{i}_cond.wav", x_cond_pixels[i, -1].cpu().to(torch.float32), sample_rate=sample_rate)
494
+
495
+ logger.info("the first 10 mel cosine: " + ", ".join(f"{v:.6f}" for v in mel_cosine_ls))
496
+
497
+
498
+ def get_args_parser():
499
+ parser = argparse.ArgumentParser()
500
+ parser.add_argument("--config", type=str, required=True)
501
+ parser.add_argument("--epochs", type=int, default=300)
502
+ parser.add_argument("--global-seed", type=int, default=0)
503
+ parser.add_argument("--log-every", type=int, default=100)
504
+ parser.add_argument("--ckpt-every", type=int, default=2000)
505
+ parser.add_argument("--eval-every", type=int, default=5000)
506
+ parser.add_argument("--bfloat16", type=int, default=1)
507
+ parser.add_argument("--torch-compile", type=int, default=1)
508
+ parser.add_argument("--restart-from-checkpoint", type=int, default=0,
509
+ help="If 1, only load model weights and reset epoch/step to zero (cold start)")
510
+ return parser
511
+
512
+ if __name__ == "__main__":
513
+ args = get_args_parser().parse_args()
514
+ main(args)
train_avwm_stage3.py ADDED
@@ -0,0 +1,532 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # NoMaD, GNM, ViNT: https://github.com/robodhruv/visualnav-transformer
9
+ # --------------------------------------------------------
10
+
11
+ from inference_avwm import model_forward_wrapper_av
12
+ import torch
13
+ # the first flag below was False when we tested this script but True makes A100 training a lot faster:
14
+ torch.backends.cuda.matmul.allow_tf32 = True
15
+ torch.backends.cudnn.allow_tf32 = True
16
+
17
+ import matplotlib
18
+ matplotlib.use('Agg')
19
+ from collections import OrderedDict
20
+ from copy import deepcopy
21
+ from time import time
22
+ import argparse
23
+ import logging
24
+ import os
25
+ import matplotlib.pyplot as plt
26
+ import yaml
27
+
28
+
29
+ import torch.distributed as dist
30
+ from torch.nn.parallel import DistributedDataParallel as DDP
31
+ from torch.utils.data import DataLoader, ConcatDataset
32
+ from torch.utils.data.distributed import DistributedSampler
33
+ from diffusers.models import AutoencoderKL
34
+
35
+ from distributed import init_distributed
36
+ from models import AVCDiT_models
37
+ from diffusion import create_diffusion
38
+ from datasets import TrainingDataset
39
+ from misc import transform
40
+ from soundstream import SoundStream
41
+ import torchaudio
42
+ from eval_audio import build_mel_transform, mel_cosine_stereo, drms_avg_db_stereo, save_ref_hat_spectrogram_panel
43
+
44
+ #################################################################################
45
+ # Training Helper Functions #
46
+ #################################################################################
47
+
48
+
49
+ def load_checkpoint_if_available(model, ema, opt, scaler, config, device, logger, args):
50
+ start_epoch = 0
51
+ train_steps = 0
52
+ latest_path = os.path.join(config['results_dir'], config['run_name'], "checkpoints", "latest.pth.tar")
53
+ if os.path.isfile(latest_path) or config.get('from_checkpoint', 0):
54
+ latest_path = latest_path if os.path.isfile(latest_path) else config.get('from_checkpoint', 0)
55
+ print("Loading model from ", latest_path)
56
+ checkpoint = torch.load(latest_path, map_location=f"cuda:{device}", weights_only=False)
57
+
58
+ ema_ckp = {k.replace('_orig_mod.', ''): v for k, v in checkpoint["ema"].items()}
59
+ model.load_state_dict(ema_ckp, strict=False)
60
+ print("Model weights loaded.")
61
+ ema.load_state_dict(ema_ckp, strict=False)
62
+ print("EMA weights loaded.")
63
+
64
+ if args.restart_from_checkpoint:
65
+ logger.info("Restarting training: epoch and step counters set to 0.")
66
+ else:
67
+ if "opt" in checkpoint:
68
+ opt_ckp = {k.replace('_orig_mod.', ''): v for k, v in checkpoint["opt"].items()}
69
+ opt.load_state_dict(opt_ckp)
70
+ print("Optimizer state loaded.")
71
+ if "scaler" in checkpoint and scaler is not None:
72
+ scaler.load_state_dict(checkpoint["scaler"])
73
+ print("GradScaler state loaded.")
74
+ if "epoch" in checkpoint:
75
+ start_epoch = checkpoint["epoch"] + 1
76
+ if "train_steps" in checkpoint:
77
+ train_steps = checkpoint["train_steps"]
78
+ logger.info(f"Resuming from epoch {start_epoch}, step {train_steps}")
79
+
80
+ return start_epoch, train_steps
81
+
82
+
83
+ @torch.no_grad()
84
+ def update_ema(ema_model, model, decay=0.9999):
85
+ """
86
+ Step the EMA model towards the current model.
87
+ """
88
+ ema_params = OrderedDict(ema_model.named_parameters())
89
+ model_params = OrderedDict(model.named_parameters())
90
+
91
+ for name, param in model_params.items():
92
+ name = name.replace('_orig_mod.', '')
93
+ ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
94
+
95
+
96
+ def requires_grad(model, flag=True):
97
+ """
98
+ Set requires_grad flag for all parameters in a model.
99
+ """
100
+ for p in model.parameters():
101
+ p.requires_grad = flag
102
+
103
+
104
+ def cleanup():
105
+ """
106
+ End DDP training.
107
+ """
108
+ dist.destroy_process_group()
109
+
110
+
111
+ def create_logger(logging_dir):
112
+ """
113
+ Create a logger that writes to a log file and stdout.
114
+ """
115
+ if dist.get_rank() == 0: # real logger
116
+ logging.basicConfig(
117
+ level=logging.INFO,
118
+ format='[\033[34m%(asctime)s\033[0m] %(message)s',
119
+ datefmt='%Y-%m-%d %H:%M:%S',
120
+ handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
121
+ )
122
+ logger = logging.getLogger(__name__)
123
+ else: # dummy logger (does nothing)
124
+ logger = logging.getLogger(__name__)
125
+ logger.addHandler(logging.NullHandler())
126
+ return logger
127
+
128
+ #################################################################################
129
+ # Training Loop #
130
+ #################################################################################
131
+
132
+ def main(args):
133
+ """
134
+ Trains a new AVCDiT model.
135
+ """
136
+ assert torch.cuda.is_available(), "Training currently requires at least one GPU."
137
+
138
+ # Setup DDP:
139
+ _, rank, device, _ = init_distributed()
140
+ # rank = dist.get_rank()
141
+ seed = args.global_seed * dist.get_world_size() + rank
142
+ torch.manual_seed(seed)
143
+ print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
144
+ with open("config/eval_config.yaml", "r") as f:
145
+ default_config = yaml.safe_load(f)
146
+ config = default_config
147
+
148
+ with open(args.config, "r") as f:
149
+ user_config = yaml.safe_load(f)
150
+ config.update(user_config)
151
+
152
+ # Setup an experiment folder:
153
+ os.makedirs(config['results_dir'], exist_ok=True) # Make results folder (holds all experiment subfolders)
154
+ experiment_dir = f"{config['results_dir']}/{config['run_name']}" # Create an experiment folder
155
+ checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
156
+ if rank == 0:
157
+ os.makedirs(checkpoint_dir, exist_ok=True)
158
+ logger = create_logger(experiment_dir)
159
+ logger.info(f"Experiment directory created at {experiment_dir}")
160
+ else:
161
+ logger = create_logger(None)
162
+
163
+ # Create model:
164
+ tokenizer_v = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema").to(device)
165
+
166
+ tokenizer_a = SoundStream(C=32, D=16, n_q=8, codebook_size=1024).to(device)
167
+ tokenizer_a_path=config["tokenizer_a_path"]
168
+ tokenizer_a_checkpoint = torch.load(tokenizer_a_path, map_location=f"cuda:{device}")
169
+ tokenizer_a.load_state_dict(tokenizer_a_checkpoint["model_state"])
170
+ tokenizer_a.eval()
171
+
172
+ latent_size = config['image_size'] // 8
173
+
174
+ assert config['image_size'] % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
175
+ num_cond = config['context_size']
176
+ model = AVCDiT_models[config['model']](context_size=num_cond, input_size=latent_size, in_channels=4).to(device)
177
+
178
+ ema = deepcopy(model).to(device) # Create an EMA of the model for use after training
179
+ requires_grad(ema, False)
180
+
181
+ # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
182
+ lr = float(config.get('lr', 1e-4))
183
+ opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0)
184
+
185
+
186
+ bfloat_enable = bool(hasattr(args, 'bfloat16') and args.bfloat16)
187
+ if bfloat_enable:
188
+ scaler = torch.amp.GradScaler()
189
+
190
+ start_epoch, train_steps = load_checkpoint_if_available(
191
+ model, ema, opt, scaler if bfloat_enable else None, config, device, logger, args
192
+ )
193
+
194
+ # ~40% speedup but might leads to worse performance depending on pytorch version
195
+ if args.torch_compile:
196
+ model = torch.compile(model)
197
+ model = DDP(model, device_ids=[device])
198
+ diffusion = create_diffusion(timestep_respacing="", dual=True) # default: 1000 steps, linear noise schedule
199
+ # ,predict_xstart=True
200
+ logger.info(f"AVCDiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
201
+
202
+ train_dataset = []
203
+ test_dataset = []
204
+
205
+ for dataset_name in config["datasets"]:
206
+ data_config = config["datasets"][dataset_name]
207
+
208
+ for data_split_type in ["train", "test"]:
209
+ if data_split_type in data_config:
210
+ goals_per_obs = int(data_config["goals_per_obs"])
211
+ if data_split_type == 'test':
212
+ goals_per_obs = 4 # standardize testing
213
+
214
+ if "distance" in data_config:
215
+ min_dist_cat=data_config["distance"]["min_dist_cat"]
216
+ max_dist_cat=data_config["distance"]["max_dist_cat"]
217
+ else:
218
+ min_dist_cat=config["distance"]["min_dist_cat"]
219
+ max_dist_cat=config["distance"]["max_dist_cat"]
220
+
221
+ if "len_traj_pred" in data_config:
222
+ len_traj_pred=data_config["len_traj_pred"]
223
+ else:
224
+ len_traj_pred=config["len_traj_pred"]
225
+
226
+ dataset = TrainingDataset(
227
+ data_folder=data_config["data_folder"],
228
+ data_split_folder=data_config[data_split_type],
229
+ dataset_name=dataset_name,
230
+ image_size=config["image_size"],
231
+ min_dist_cat=min_dist_cat,
232
+ max_dist_cat=max_dist_cat,
233
+ len_traj_pred=len_traj_pred,
234
+ context_size=config["context_size"],
235
+ normalize=config["normalize"],
236
+ goals_per_obs=goals_per_obs,
237
+ transform=transform,
238
+ predefined_index=None,
239
+ traj_stride=1,
240
+ sample_rate=config["sample_rate"],
241
+ # target_len=7840 #TODO
242
+ input_sr=config["input_sr"],
243
+ evaluate=(data_split_type=="test")
244
+ )
245
+ if data_split_type == "train":
246
+ train_dataset.append(dataset)
247
+ else:
248
+ test_dataset.append(dataset)
249
+ print(f"Dataset: {dataset_name} ({data_split_type}), size: {len(dataset)}")
250
+
251
+ # combine all the datasets from different robots
252
+ print(f"Combining {len(train_dataset)} datasets.")
253
+ train_dataset = ConcatDataset(train_dataset)
254
+ test_dataset = ConcatDataset(test_dataset)
255
+
256
+ sampler = DistributedSampler(
257
+ train_dataset,
258
+ num_replicas=dist.get_world_size(),
259
+ rank=rank,
260
+ shuffle=True,
261
+ seed=args.global_seed
262
+ )
263
+ loader = DataLoader(
264
+ train_dataset,
265
+ batch_size=config['batch_size'],
266
+ shuffle=False,
267
+ sampler=sampler,
268
+ num_workers=config['num_workers'],
269
+ pin_memory=True,
270
+ drop_last=True,
271
+ persistent_workers=True
272
+ )
273
+ logger.info(f"Dataset contains {len(train_dataset):,} images")
274
+
275
+ # Prepare models for training:
276
+ model.train() # important! This enables embedding dropout for classifier-free guidance
277
+ ema.eval() # EMA model should always be in eval mode
278
+
279
+ # Variables for monitoring/logging purposes:
280
+ log_steps = 0
281
+ running_loss = 0
282
+ start_time = time()
283
+
284
+ logger.info(f"Training for {args.epochs} epochs...")
285
+ for epoch in range(start_epoch, args.epochs):
286
+ sampler.set_epoch(epoch)
287
+ steps_per_epoch = len(loader)
288
+ if rank == 0:
289
+ logger.info(f"Epoch {epoch} contains {steps_per_epoch} steps.")
290
+ logger.info(f"Beginning epoch {epoch}...")
291
+
292
+ for x_v, x_a, y, diff, rel_t in loader:
293
+ x_v = x_v.to(device, non_blocking=True)
294
+ x_a = x_a.to(device, non_blocking=True)
295
+ y = y.to(device, non_blocking=True)
296
+ diff = diff.to(device, non_blocking=True)
297
+ rel_t = rel_t.to(device, non_blocking=True)
298
+
299
+ with torch.amp.autocast('cuda', enabled=bfloat_enable, dtype=torch.bfloat16):
300
+ with torch.no_grad():
301
+ # Map input images to latent space + normalize latents:
302
+ B, T = x_v.shape[:2]
303
+ #=== vision observation encoding
304
+ x_v = x_v.flatten(0,1)
305
+ x_v = tokenizer_v.encode(x_v).latent_dist.sample().mul_(0.18215)
306
+ x_v = x_v.unflatten(0, (B, T))
307
+ #=== audio observation encoding
308
+ x_a = x_a.flatten(0,1)
309
+ x_a = tokenizer_a.encoder(x_a)
310
+ x_a = x_a.unflatten(0, (B, T))
311
+
312
+ num_goals = T - num_cond
313
+ #=== split into target and condition
314
+ x_v_start = x_v[:, num_cond:].flatten(0, 1)
315
+ x_v_cond = x_v[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x_v.shape[2], x_v.shape[3], x_v.shape[4]).flatten(0, 1)
316
+ x_a_start = x_a[:, num_cond:].flatten(0, 1)
317
+ x_a_cond = x_a[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x_a.shape[2], x_a.shape[3]).flatten(0, 1)
318
+ #===
319
+ y = y.flatten(0, 1)
320
+ rel_t = rel_t.flatten(0, 1)
321
+
322
+
323
+
324
+ diff = diff.flatten(0, 1) # [N, 1]
325
+ diff_tok = diff.unsqueeze(1).expand(-1, 16, -1) # [N, 64, 1]
326
+ x_a_start = torch.cat([x_a_start, diff_tok], dim=2) # [N, 64, 181]
327
+
328
+ t = torch.randint(0, diffusion.num_timesteps, (x_v_start.shape[0],), device=device)
329
+ model_kwargs = dict(y=y, x_v_cond=x_v_cond, x_a_cond=x_a_cond, rel_t=rel_t)
330
+ loss_dict = diffusion.training_losses(model, x_v_start, x_a_start, t, model_kwargs)
331
+ loss = loss_dict["loss"].mean()
332
+
333
+ if not bfloat_enable:
334
+ opt.zero_grad()
335
+ loss.backward()
336
+ opt.step()
337
+ else:
338
+ scaler.scale(loss).backward()
339
+ if config.get('grad_clip_val', 0) > 0:
340
+ scaler.unscale_(opt)
341
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config['grad_clip_val'])
342
+ scaler.step(opt)
343
+ scaler.update()
344
+
345
+ update_ema(ema, model.module)
346
+
347
+ # Log loss values:
348
+ running_loss += loss.detach().item()
349
+ log_steps += 1
350
+ train_steps += 1
351
+ if train_steps % args.log_every == 0:
352
+ # Measure training speed:
353
+ torch.cuda.synchronize()
354
+ end_time = time()
355
+ steps_per_sec = log_steps / (end_time - start_time)
356
+ samples_per_sec = dist.get_world_size()*x_v_cond.shape[0]*steps_per_sec
357
+ # Reduce loss history over all processes:
358
+ avg_loss = torch.tensor(running_loss / log_steps, device=device)
359
+ dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
360
+ avg_loss = avg_loss.item() / dist.get_world_size()
361
+ total_steps = len(loader) * args.epochs
362
+ progress_pct = train_steps / total_steps * 100
363
+
364
+ remaining_steps = total_steps - train_steps
365
+ eta_seconds = remaining_steps / steps_per_sec if steps_per_sec > 0 else 0
366
+ eta_hours = eta_seconds / 3600
367
+
368
+ logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}, Samples/Sec: {samples_per_sec:.2f}")
369
+ logger.info(f"Progress: {progress_pct:.2f}% | ETA: {eta_hours:.1f}h")
370
+ # Reset monitoring variables:
371
+ running_loss = 0
372
+ log_steps = 0
373
+ start_time = time()
374
+
375
+ # Save DiT checkpoint:
376
+ if train_steps % args.ckpt_every == 0 and train_steps > 0:
377
+ if rank == 0:
378
+ checkpoint = {
379
+ "model": model.module.state_dict(),
380
+ "ema": ema.state_dict(),
381
+ "opt": opt.state_dict(),
382
+ "args": args,
383
+ "epoch": epoch,
384
+ "train_steps": train_steps
385
+ }
386
+ if bfloat_enable:
387
+ checkpoint.update({"scaler": scaler.state_dict()})
388
+ checkpoint_path = f"{checkpoint_dir}/latest.pth.tar"
389
+ torch.save(checkpoint, checkpoint_path)
390
+ if train_steps % (10*args.ckpt_every) == 0 and train_steps > 0:
391
+ checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pth.tar"
392
+ torch.save(checkpoint, checkpoint_path)
393
+ logger.info(f"Saved checkpoint to {checkpoint_path}")
394
+
395
+ if train_steps % args.eval_every == 0 and train_steps > 0:
396
+ eval_start_time = time()
397
+ # validation / test set evaluation
398
+ save_dir = os.path.join(experiment_dir, str(train_steps))
399
+ sim_score_val = evaluate(ema, tokenizer_v, tokenizer_a, diffusion, test_dataset, rank, config["batch_size"], config["num_workers"], latent_size, device, save_dir, args.global_seed, bfloat_enable, num_cond, config["sample_rate"], config["input_sr"], logger)
400
+ dist.barrier()
401
+ eval_end_time = time()
402
+ eval_time = eval_end_time - eval_start_time
403
+ # logger.info(f"(step={train_steps:07d}) Val Perceptual Loss: {sim_score_val:.4f}, Train Perceptual Loss: {sim_score_train:.4f}, Eval Time: {eval_time:.2f}")
404
+ logger.info(f"(step={train_steps:07d}) Val Perceptual Loss: {sim_score_val:.4f}, Eval Time: {eval_time:.2f}")
405
+
406
+ model.eval() # important! This disables randomized embedding dropout
407
+ # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
408
+
409
+ logger.info("Done!")
410
+ cleanup()
411
+
412
+ def denormalize_dis(ndata: float, min_v=-20.0, max_v=20.0, scale=0.15):
413
+ n01 = (float(ndata) + 1.0) / 2.0
414
+ raw = n01 * (max_v - min_v) + min_v
415
+ return raw * scale
416
+
417
+ @torch.no_grad
418
+ def evaluate(model, vae, sstream, diffusion, test_dataloaders, rank, batch_size, num_workers, latent_size, device, save_dir, seed, bfloat_enable, num_cond, sample_rate, input_sr, logger):
419
+ sampler = DistributedSampler(
420
+ test_dataloaders,
421
+ num_replicas=dist.get_world_size(),
422
+ rank=rank,
423
+ shuffle=True,
424
+ seed=seed
425
+ )
426
+ loader = DataLoader(
427
+ test_dataloaders,
428
+ batch_size=batch_size,
429
+ shuffle=False,
430
+ sampler=sampler,
431
+ num_workers=num_workers,
432
+ pin_memory=True,
433
+ drop_last=True
434
+ )
435
+ from dreamsim import dreamsim
436
+ eval_model, _ = dreamsim(pretrained=True)
437
+ score = torch.tensor(0.).to(device)
438
+ n_samples = torch.tensor(0).to(device)
439
+
440
+ down_resampler = torchaudio.transforms.Resample(orig_freq=input_sr, new_freq=sample_rate, lowpass_filter_width=64).to(device, dtype=torch.bfloat16)
441
+ mel_tf = build_mel_transform(
442
+ sample_rate=sample_rate,
443
+ n_fft=1024, win_length=1024, hop_length=256,
444
+ n_mels=80, power=1.0,
445
+ device=device, # or ref.device
446
+ )
447
+ # Run for 1 step
448
+ for x_v, x_a, y, diff, rel_t, x_a_orig in loader:
449
+ x_v = x_v.to(device)
450
+ x_a = x_a.to(device)
451
+ x_a_orig = x_a_orig.to(device)
452
+ y = y.to(device)
453
+ diff = diff.to(device).flatten(0, 1)
454
+ rel_t = rel_t.to(device).flatten(0, 1)
455
+ with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
456
+ B, T = x_v.shape[:2]
457
+ num_goals = T - num_cond
458
+ samples_v, samples_a, diff_pred = model_forward_wrapper_av((model, diffusion, vae, sstream), (x_v, x_a), y, num_timesteps=None, latent_size=latent_size, device=device, num_cond=num_cond, num_goals=num_goals, rel_t=rel_t)
459
+
460
+ samples_a = down_resampler(samples_a) #
461
+
462
+ x_start_pixels = x_v[:, num_cond:].flatten(0, 1)
463
+ x_cond_pixels = x_v[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x_v.shape[2], x_v.shape[3], x_v.shape[4]).flatten(0, 1)
464
+ samples_v = samples_v * 0.5 + 0.5
465
+ x_start_pixels = x_start_pixels * 0.5 + 0.5
466
+ x_cond_pixels = x_cond_pixels * 0.5 + 0.5
467
+ res = eval_model(x_start_pixels, samples_v)
468
+ score += res.sum()
469
+ n_samples += len(res)
470
+
471
+ # x_start_audio = x_a[:, num_cond:].flatten(0, 1)
472
+ # x_cond_audio = x_a[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x_a.shape[2], x_a.shape[3]).flatten(0, 1)
473
+ x_start_audio = x_a_orig[:, num_cond:].flatten(0, 1)
474
+ x_cond_audio = x_a_orig[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x_a_orig.shape[2], x_a_orig.shape[3]).flatten(0, 1)
475
+ break
476
+
477
+ if rank == 0:
478
+ os.makedirs(save_dir, exist_ok=True)
479
+
480
+ if diff is not None:
481
+ mae = torch.mean(torch.abs(diff_pred - diff))
482
+ logger.info(f"Distance Diff MAE = {mae.item():.6f}")
483
+
484
+ mel_cosine_ls=[]
485
+ for i in range(min(samples_v.shape[0], 10)):
486
+ _, ax = plt.subplots(1,3,dpi=256)
487
+ ax[0].imshow((x_cond_pixels[i, -1].permute(1,2,0).cpu().numpy()*255).astype('uint8'))
488
+ ax[1].imshow((x_start_pixels[i].permute(1,2,0).cpu().numpy()*255).astype('uint8'))
489
+ ax[2].imshow((samples_v[i].permute(1,2,0).cpu().float().numpy()*255).astype('uint8'))
490
+ plt.savefig(f'{save_dir}/{i}.png')
491
+ plt.close()
492
+
493
+
494
+ mel_cos = mel_cosine_stereo(x_start_audio[i], samples_a[i], sample_rate=sample_rate, mel_tf=mel_tf)
495
+ mel_cosine_ls.append(mel_cos)
496
+ ok = save_ref_hat_spectrogram_panel(
497
+ x_start_audio[i], samples_a[i],
498
+ out_path=f"{save_dir}/{i}_spectrograms.png",
499
+ n_fft=512, hop_length=160, win_length=400, pool=4,
500
+ title="gt vs pred"
501
+ )
502
+
503
+ # sr = int(16000 * 7840 / 2400) #TODO
504
+ torchaudio.save(f"{save_dir}/{i}_gen.wav", samples_a[i].cpu().to(torch.float32), sample_rate=sample_rate)
505
+ torchaudio.save(f"{save_dir}/{i}_gt.wav", x_start_audio[i].cpu().to(torch.float32), sample_rate=sample_rate)
506
+ torchaudio.save(f"{save_dir}/{i}_cond.wav", x_cond_audio[i, -1].cpu().to(torch.float32), sample_rate=sample_rate)
507
+ logger.info("the first 10 mel cosine: " + ", ".join(f"{v:.6f}" for v in mel_cosine_ls))
508
+
509
+
510
+ dist.all_reduce(score)
511
+ dist.all_reduce(n_samples)
512
+ sim_score = score/n_samples
513
+ return sim_score
514
+
515
+
516
+ def get_args_parser():
517
+ parser = argparse.ArgumentParser()
518
+ parser.add_argument("--config", type=str, required=True)
519
+ parser.add_argument("--epochs", type=int, default=300)
520
+ parser.add_argument("--global-seed", type=int, default=0)
521
+ parser.add_argument("--log-every", type=int, default=100)
522
+ parser.add_argument("--ckpt-every", type=int, default=2000)
523
+ parser.add_argument("--eval-every", type=int, default=5000)
524
+ parser.add_argument("--bfloat16", type=int, default=1)
525
+ parser.add_argument("--torch-compile", type=int, default=1)
526
+ parser.add_argument("--restart-from-checkpoint", type=int, default=0,
527
+ help="If 1, only load model weights and reset epoch/step to zero (cold start)")
528
+ return parser
529
+
530
+ if __name__ == "__main__":
531
+ args = get_args_parser().parse_args()
532
+ main(args)