krystv commited on
Commit
80f3820
·
verified ·
1 Parent(s): 8a00562

Upload liquid_diffusion/trainer.py

Browse files
Files changed (1) hide show
  1. liquid_diffusion/trainer.py +193 -0
liquid_diffusion/trainer.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Rectified Flow Training for LiquidDiffusion
3
+
4
+ Training Objective (Rectified Flow):
5
+ x_t = (1-t)*x0 + t*x1, t ~ U[0,1], x1 ~ N(0,I)
6
+ v_target = x1 - x0 (constant velocity)
7
+ L = E[||v_θ(x_t, t) - v_target||²] (simple MSE)
8
+
9
+ Sampling (Euler ODE):
10
+ Start from x_1 ~ N(0,I), integrate backward:
11
+ x_{t-dt} = x_t - v_θ(x_t, t) * dt
12
+
13
+ References:
14
+ [1] Liu et al., "Flow Straight and Fast: Rectified Flow", ICLR 2023
15
+ [2] Lee et al., "Improving the Training of Rectified Flows", 2024
16
+ """
17
+
18
+ import math
19
+ import copy
20
+ import os
21
+ import time
22
+ import json
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+ from torch.utils.data import DataLoader, Dataset
27
+ from torchvision import transforms
28
+ from torchvision.utils import save_image, make_grid
29
+
30
+
31
+ class RectifiedFlowTrainer:
32
+ """Trainer for LiquidDiffusion using Rectified Flow objective."""
33
+
34
+ def __init__(self, model, optimizer=None, lr=1e-4, weight_decay=0.01,
35
+ ema_decay=0.9999, grad_clip=1.0, time_sampling="logit_normal",
36
+ logit_normal_mean=0.0, logit_normal_std=1.0, device="cuda",
37
+ use_amp=True, amp_dtype="float16"):
38
+ self.model = model.to(device)
39
+ self.device = device
40
+ self.ema_decay = ema_decay
41
+ self.grad_clip = grad_clip
42
+ self.time_sampling = time_sampling
43
+ self.logit_normal_mean = logit_normal_mean
44
+ self.logit_normal_std = logit_normal_std
45
+ self.use_amp = use_amp and device == "cuda"
46
+ self.amp_dtype = getattr(torch, amp_dtype) if self.use_amp else torch.float32
47
+
48
+ if optimizer is None:
49
+ self.optimizer = torch.optim.AdamW(
50
+ model.parameters(), lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999))
51
+ else:
52
+ self.optimizer = optimizer
53
+
54
+ self.scaler = torch.amp.GradScaler("cuda", enabled=(self.use_amp and amp_dtype == "float16"))
55
+ self.ema_model = self._build_ema()
56
+ self.step = 0
57
+ self.losses = []
58
+
59
+ def _build_ema(self):
60
+ ema = copy.deepcopy(self.model)
61
+ ema.eval()
62
+ for p in ema.parameters():
63
+ p.requires_grad_(False)
64
+ return ema
65
+
66
+ @torch.no_grad()
67
+ def _update_ema(self):
68
+ for ema_p, model_p in zip(self.ema_model.parameters(), self.model.parameters()):
69
+ ema_p.data.mul_(self.ema_decay).add_(model_p.data, alpha=1 - self.ema_decay)
70
+
71
+ def _sample_time(self, batch_size):
72
+ eps = 1e-5
73
+ if self.time_sampling == "uniform":
74
+ return torch.rand(batch_size, device=self.device) * (1 - 2*eps) + eps
75
+ elif self.time_sampling == "logit_normal":
76
+ u = torch.randn(batch_size, device=self.device) * self.logit_normal_std + self.logit_normal_mean
77
+ return torch.sigmoid(u).clamp(eps, 1 - eps)
78
+ raise ValueError(f"Unknown time_sampling: {self.time_sampling}")
79
+
80
+ def train_step(self, x0):
81
+ self.model.train()
82
+ x1 = torch.randn_like(x0)
83
+ t = self._sample_time(x0.shape[0])
84
+ t_expand = t[:, None, None, None]
85
+ x_t = (1 - t_expand) * x0 + t_expand * x1
86
+ v_target = x1 - x0
87
+
88
+ with torch.amp.autocast(self.device, dtype=self.amp_dtype, enabled=self.use_amp):
89
+ v_pred = self.model(x_t, t)
90
+ loss = F.mse_loss(v_pred, v_target)
91
+
92
+ self.optimizer.zero_grad(set_to_none=True)
93
+ self.scaler.scale(loss).backward()
94
+ if self.grad_clip > 0:
95
+ self.scaler.unscale_(self.optimizer)
96
+ grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
97
+ else:
98
+ grad_norm = torch.tensor(0.0)
99
+ self.scaler.step(self.optimizer)
100
+ self.scaler.update()
101
+ self._update_ema()
102
+ self.step += 1
103
+ loss_val = loss.item()
104
+ self.losses.append(loss_val)
105
+ return {'loss': loss_val, 'grad_norm': grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm, 'step': self.step}
106
+
107
+ @torch.no_grad()
108
+ def sample(self, batch_size=4, image_size=256, channels=3, num_steps=50, use_ema=True):
109
+ model = self.ema_model if use_ema else self.model
110
+ model.eval()
111
+ z = torch.randn(batch_size, channels, image_size, image_size, device=self.device)
112
+ dt = 1.0 / num_steps
113
+ for i in range(num_steps, 0, -1):
114
+ t = torch.full((batch_size,), i / num_steps, device=self.device)
115
+ with torch.amp.autocast(self.device, dtype=self.amp_dtype, enabled=self.use_amp):
116
+ v = model(z, t)
117
+ if self.use_amp and self.amp_dtype == torch.float16:
118
+ v = v.float()
119
+ z = z - v * dt
120
+ return z.clamp(-1, 1)
121
+
122
+ def save_checkpoint(self, path, extra=None):
123
+ ckpt = {'model': self.model.state_dict(), 'ema_model': self.ema_model.state_dict(),
124
+ 'optimizer': self.optimizer.state_dict(), 'scaler': self.scaler.state_dict(),
125
+ 'step': self.step, 'losses': self.losses[-1000:]}
126
+ if extra: ckpt.update(extra)
127
+ os.makedirs(os.path.dirname(path) if os.path.dirname(path) else '.', exist_ok=True)
128
+ torch.save(ckpt, path)
129
+
130
+ def load_checkpoint(self, path):
131
+ ckpt = torch.load(path, map_location=self.device, weights_only=False)
132
+ self.model.load_state_dict(ckpt['model'])
133
+ self.ema_model.load_state_dict(ckpt['ema_model'])
134
+ self.optimizer.load_state_dict(ckpt['optimizer'])
135
+ if 'scaler' in ckpt: self.scaler.load_state_dict(ckpt['scaler'])
136
+ self.step = ckpt.get('step', 0)
137
+ self.losses = ckpt.get('losses', [])
138
+ print(f"Resumed from step {self.step}")
139
+
140
+
141
+ class ImageDataset(Dataset):
142
+ """Image dataset from local folder or HuggingFace Hub."""
143
+ def __init__(self, source, image_size=256, split="train",
144
+ image_column="image", max_samples=None, hf_dataset=None):
145
+ self.image_size = image_size
146
+ self.image_column = image_column
147
+ self.transform = transforms.Compose([
148
+ transforms.Resize(image_size, interpolation=transforms.InterpolationMode.LANCZOS),
149
+ transforms.CenterCrop(image_size),
150
+ transforms.RandomHorizontalFlip(),
151
+ transforms.ToTensor(),
152
+ transforms.Normalize([0.5], [0.5]),
153
+ ])
154
+ if hf_dataset is not None:
155
+ self.data = hf_dataset
156
+ self.mode = "hf"
157
+ elif source and os.path.isdir(source):
158
+ from glob import glob
159
+ self.files = []
160
+ for ext in ['*.png', '*.jpg', '*.jpeg', '*.webp', '*.bmp']:
161
+ self.files.extend(glob(os.path.join(source, '**', ext), recursive=True))
162
+ self.files.sort()
163
+ if max_samples: self.files = self.files[:max_samples]
164
+ self.mode = "folder"
165
+ else:
166
+ from datasets import load_dataset
167
+ self.data = load_dataset(source, split=split)
168
+ if max_samples: self.data = self.data.select(range(min(max_samples, len(self.data))))
169
+ self.mode = "hf"
170
+
171
+ def __len__(self):
172
+ return len(self.files) if self.mode == "folder" else len(self.data)
173
+
174
+ def __getitem__(self, idx):
175
+ if self.mode == "folder":
176
+ from PIL import Image
177
+ img = Image.open(self.files[idx]).convert("RGB")
178
+ else:
179
+ img = self.data[idx][self.image_column]
180
+ if not hasattr(img, 'convert'):
181
+ from PIL import Image as PILImage
182
+ img = PILImage.fromarray(img)
183
+ img = img.convert("RGB")
184
+ return self.transform(img)
185
+
186
+
187
+ def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
188
+ def lr_lambda(step):
189
+ if step < num_warmup_steps:
190
+ return float(step) / float(max(1, num_warmup_steps))
191
+ progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
192
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
193
+ return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)