"""
Flux LoRA DDP Training Script
- Multi-GPU DDP via accelerate (written for 2 GPUs)
- bf16 mixed precision
- Gradient checkpointing
- WebDataset loading
- Checkpoint every --save-steps steps (default 1000) with auto-upload to the Hugging Face Hub
- Auto-resume from the latest checkpoint
"""
import os
import sys
import time
import math
import torch
import torch.nn.functional as F
from pathlib import Path
from torch.utils.data import DataLoader
import webdataset as wds
from accelerate import Accelerator
from accelerate.utils import set_seed
from diffusers import FluxPipeline, FlowMatchEulerDiscreteScheduler
from diffusers.training_utils import compute_density_for_timestep_sampling
from peft import LoraConfig, get_peft_model
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from huggingface_hub import HfApi, upload_folder
from torchvision import transforms
from PIL import Image
import io
def get_args(argv=None):
    """Parse the training script's command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so the CLI behaves exactly as
            before while tests and wrappers can pass arguments explicitly.

    Returns:
        argparse.Namespace with one attribute per option; dashes become
        underscores (``--data-dir`` -> ``args.data_dir``).
    """
    import argparse

    p = argparse.ArgumentParser(description="Flux LoRA DDP training")
    p.add_argument("--model-name", default="black-forest-labs/FLUX.1-dev",
                   help="Hub id or local path of the Flux pipeline to fine-tune")
    p.add_argument("--data-dir", required=True,
                   help="Directory containing WebDataset .tar shards")
    p.add_argument("--output-dir", required=True,
                   help="Directory for checkpoints, samples and the final LoRA")
    p.add_argument("--batch-size", type=int, default=1, help="Per-device batch size")
    p.add_argument("--gradient-accumulation", type=int, default=4,
                   help="Micro-batches accumulated per optimizer step")
    p.add_argument("--learning-rate", type=float, default=1e-4, help="AdamW learning rate")
    p.add_argument("--lr-warmup-steps", type=int, default=100,
                   help="Number of linear learning-rate warmup steps")
    p.add_argument("--max-train-steps", type=int, default=100000,
                   help="Total number of optimizer steps")
    p.add_argument("--save-steps", type=int, default=1000,
                   help="Save (and upload) a checkpoint every N steps")
    p.add_argument("--sample-steps", type=int, default=1000,
                   help="Render sample images every N steps")
    p.add_argument("--lora-rank", type=int, default=128, help="LoRA rank r")
    p.add_argument("--lora-alpha", type=int, default=64, help="LoRA alpha (scaling numerator)")
    p.add_argument("--max-grad-norm", type=float, default=1.0, help="Gradient clipping norm")
    p.add_argument("--seed", type=int, default=42, help="Random seed")
    p.add_argument("--resolution", type=int, default=1024,
                   help="Square training resolution in pixels (resize + center crop)")
    p.add_argument("--hf-user", default="memoryai", help="Hugging Face user/org for uploads")
    p.add_argument("--hf-repo", default="4k-image-model-checkpoints",
                   help="Hugging Face model repo name for uploads")
    return p.parse_args(argv)
def create_webdataset(data_dir, resolution, tokenizer, tokenizer_2):
    """Build a streaming WebDataset of preprocessed image/caption samples.

    Reads every ``*.tar`` shard in *data_dir* (sorted, split across processes
    with ``wds.split_by_node``), decodes images with PIL, resizes the shorter
    side to *resolution* (Lanczos), center-crops to a square, and scales pixels
    to [-1, 1]. The ``txt`` caption is tokenized with *tokenizer* (77 tokens)
    and *tokenizer_2* (512 tokens), padded/truncated to those fixed lengths.

    Each yielded item is a batch of ONE sample: a dict of tensors with a
    leading dim of 1 under ``pixel_values``, ``input_ids_1``,
    ``attention_mask_1``, ``input_ids_2`` and ``attention_mask_2``.
    Samples with no image, or that fail preprocessing, are skipped. Corrupt
    tar members and undecodable files are logged and skipped rather than
    aborting the whole loader pass.

    Raises:
        ValueError: if *data_dir* contains no ``.tar`` shards, or (when
            torch.distributed is initialized) fewer shards than processes,
            which would leave some rank with no data at all.
    """
    image_keys = ("jpg", "png", "jpeg", "webp")
    transform = transforms.Compose([
        transforms.Resize(resolution, interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    def process_sample(sample):
        """Turn one raw WebDataset sample into model inputs, or None to drop it."""
        try:
            image = next((sample[k] for k in image_keys if sample.get(k) is not None), None)
            if image is None:
                return None
            if isinstance(image, Image.Image):
                image = image.convert("RGB")
            else:
                # Not decoded by .decode("pil"): treat it as encoded image bytes.
                image = Image.open(io.BytesIO(image)).convert("RGB")
            pixel_values = transform(image)
            caption = sample.get("txt", "")
            if isinstance(caption, bytes):
                caption = caption.decode("utf-8")
            tokens_1 = tokenizer(
                caption, max_length=77, padding="max_length",
                truncation=True, return_tensors="pt",
            )
            tokens_2 = tokenizer_2(
                caption, max_length=512, padding="max_length",
                truncation=True, return_tensors="pt",
            )
        except Exception:
            # Best effort: one bad sample must not stop a long training run.
            return None
        return {
            "pixel_values": pixel_values,
            "input_ids_1": tokens_1.input_ids.squeeze(0),
            "attention_mask_1": tokens_1.attention_mask.squeeze(0),
            "input_ids_2": tokens_2.input_ids.squeeze(0),
            "attention_mask_2": tokens_2.attention_mask.squeeze(0),
        }

    def collate(batch):
        """Stack each per-sample tensor along a new leading batch dim."""
        return {key: torch.stack([sample[key] for sample in batch]) for key in batch[0]}

    shards = sorted(str(p) for p in Path(data_dir).glob("*.tar"))
    if not shards:
        raise ValueError(f"No .tar shards found in {data_dir}")
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        world_size = torch.distributed.get_world_size()
        # split_by_node hands out whole shards round-robin; with fewer shards than
        # ranks some rank would receive nothing and fail later with no clear cause.
        if len(shards) < world_size:
            raise ValueError(
                f"Found {len(shards)} .tar shards in {data_dir} but {world_size} processes; "
                "need at least one shard per process"
            )

    return (
        wds.WebDataset(
            shards,
            shardshuffle=1000,
            nodesplitter=wds.split_by_node,
            empty_check=False,
            handler=wds.warn_and_continue,
        )
        .decode("pil", handler=wds.warn_and_continue)
        .shuffle(1000)
        .map(process_sample)
        .select(lambda x: x is not None)
        .batched(1, collation_fn=collate)
    )
def find_latest_checkpoint(output_dir):
    """Return the newest resumable checkpoint under *output_dir*.

    Checkpoints are sub-directories named ``checkpoint-<N>``, where ``N`` is a
    decimal integer compared numerically. A checkpoint is resumable when it
    holds both ``lora_weights.pt`` and a loadable ``training_state.pt``.

    Candidates are tried newest first. A checkpoint left incomplete by an
    interrupted save therefore falls back to the previous one, instead of
    silently restarting training from scratch.

    Returns:
        ``(checkpoint_path, global_step)``, or ``(None, 0)`` if none qualifies.
        ``global_step`` is read from the state file (0 if the key is absent).
    """
    output_path = Path(output_dir)
    if not output_path.is_dir():
        return None, 0

    candidates = []
    for entry in output_path.iterdir():
        prefix, sep, step = entry.name.partition("-")
        if entry.is_dir() and prefix == "checkpoint" and sep and step.isdecimal():
            candidates.append((int(step), entry))

    for _, ckpt in sorted(candidates, key=lambda c: c[0], reverse=True):
        state_file = ckpt / "training_state.pt"
        if not (state_file.is_file() and (ckpt / "lora_weights.pt").is_file()):
            continue  # incomplete save; try an older checkpoint
        try:
            state = torch.load(state_file, map_location="cpu")
        except Exception as e:  # truncated/corrupt file, e.g. from a killed save
            print(f" Skipping unreadable checkpoint {ckpt.name}: {e}")
            continue
        return ckpt, state.get("global_step", 0)
    return None, 0
def upload_checkpoint(output_dir, checkpoint_name, hf_user, hf_repo):
    """Best-effort push of a checkpoint folder and sample PNGs to the Hugging Face Hub.

    Uploads ``<output_dir>/<checkpoint_name>`` to ``flux_lora_4k/<checkpoint_name>``.
    When ``<output_dir>/samples`` contains PNGs, that folder goes to
    ``flux_lora_4k/samples``. Both land in the model repo ``<hf_user>/<hf_repo>``,
    which is created (private) if it does not exist yet.

    Every failure is printed and swallowed so a Hub problem never stops training.
    """
    repo_id = f"{hf_user}/{hf_repo}"
    output_path = Path(output_dir)
    try:
        api = HfApi()
        # exist_ok makes creation idempotent. A repo_info() probe would misread
        # any auth/network error as "repo missing" and race with other writers.
        api.create_repo(repo_id=repo_id, repo_type="model", private=True, exist_ok=True)

        ckpt_path = output_path / checkpoint_name
        if ckpt_path.is_dir():
            path_in_repo = f"flux_lora_4k/{checkpoint_name}"
            api.upload_folder(
                folder_path=str(ckpt_path),
                repo_id=repo_id,
                path_in_repo=path_in_repo,
                repo_type="model",
            )
            print(f" Uploaded {checkpoint_name} -> {repo_id}/{path_in_repo}")
        else:
            print(f" Upload skipped: {ckpt_path} does not exist")

        samples_dir = output_path / "samples"
        if samples_dir.is_dir() and any(samples_dir.glob("*.png")):
            api.upload_folder(
                folder_path=str(samples_dir),
                repo_id=repo_id,
                path_in_repo="flux_lora_4k/samples",
                repo_type="model",
            )
    except Exception as e:
        print(f" Upload failed (non-fatal): {e}")
def generate_samples(accelerator, pipe, output_dir, step, prompts=None, *,
                     height=1024, width=1024, num_inference_steps=20,
                     guidance_scale=3.5, seed=0):
    """Render sample images with the current weights, on the main process only.

    Args:
        accelerator: Provides ``is_main_process`` and ``device``. Non-main
            ranks return immediately.
        pipe: Text-to-image pipeline. It is moved to ``accelerator.device`` and
            called once per prompt; ``.images[0]`` of each result is saved.
        output_dir: Images are written to
            ``<output_dir>/samples/step_<step:06d>_<i>.png``.
        step: Training step used in the file names.
        prompts: Prompts to render. A built-in set of four is used when None.
        height, width, num_inference_steps, guidance_scale: Forwarded to ``pipe``.
        seed: Prompt ``i`` is rendered with a CPU generator seeded
            ``seed + i``, so images from different training steps start from
            the same noise and can be compared directly.

    Any failure is printed and swallowed: sampling must never stop training.
    """
    if not accelerator.is_main_process:
        return
    if prompts is None:
        prompts = [
            "A stunning 4K photograph of a mountain landscape at golden hour",
            "A detailed close-up of a butterfly on a flower, 4K ultra HD",
            "A modern city skyline at night with reflections, high resolution",
            "A portrait of an elderly craftsman in his workshop, natural lighting",
        ]
    samples_dir = Path(output_dir) / "samples"
    samples_dir.mkdir(parents=True, exist_ok=True)
    try:
        pipe.to(accelerator.device)
        with torch.no_grad():
            for i, prompt in enumerate(prompts):
                # CPU generator: reproducible regardless of the GPU model.
                generator = torch.Generator("cpu").manual_seed(seed + i)
                image = pipe(
                    prompt=prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    height=height,
                    width=width,
                    generator=generator,
                ).images[0]
                image.save(samples_dir / f"step_{step:06d}_{i}.png")
        print(f" Samples saved at step {step}")
    except Exception as e:
        print(f" Sample generation failed (non-fatal): {e}")
def _concat_batches(items):
    """Merge dataset items that already carry a leading batch dim into one batch.

    create_webdataset() yields single-sample batches (leading dim 1).
    Concatenating ``--batch-size`` of them along dim 0 makes the per-device
    batch size follow the CLI flag.
    """
    return {key: torch.cat([item[key] for item in items], dim=0) for key in items[0]}


def _infinite_batches(dataloader, rank):
    """Yield batches forever, starting a new pass whenever the loader is exhausted.

    A loader error after at least one batch of the current pass is printed and
    the loader restarted, as a best effort for long runs. An error before any
    batch, or a pass that yields nothing, is fatal, so an empty or broken
    dataset fails loudly instead of looping forever.
    """
    while True:
        produced = False
        try:
            for batch in dataloader:
                produced = True
                yield batch
        except Exception as e:
            if not produced:
                raise
            print(f" [rank {rank}] Data loader error, restarting: {e}")
            continue
        if not produced:
            raise RuntimeError(
                f"Rank {rank} received no training samples; check that every "
                "process gets at least one non-empty shard."
            )


def _pack_latents(latents):
    """Patchify VAE latents (B, C, H, W) into Flux tokens (B, H/2 * W/2, C * 4)."""
    b, c, h, w = latents.shape
    latents = latents.reshape(b, c, h // 2, 2, w // 2, 2)
    return latents.permute(0, 2, 4, 1, 3, 5).reshape(b, (h // 2) * (w // 2), c * 4)


def _latent_image_ids(height, width, device, dtype):
    """Return (height * width, 3) position ids for a height x width grid of packed tokens.

    Column 0 is zero, column 1 holds the row index and column 2 the column
    index, in the same row-major token order that _pack_latents() produces.
    """
    ids = torch.zeros(height, width, 3)
    ids[..., 1] = torch.arange(height)[:, None]
    ids[..., 2] = torch.arange(width)[None, :]
    return ids.reshape(height * width, 3).to(device=device, dtype=dtype)


def _compute_loss(transformer, vae, text_encoder, text_encoder_2, batch, device, guidance_scale):
    """Return the rectified-flow MSE loss (float32 scalar) for one micro-batch.

    Images are VAE-encoded and packed into tokens x0. They are noised as
    ``x_t = (1 - t) * x0 + t * noise`` with ``t ~ U[0, 1)``, and the
    transformer is trained to predict ``noise - x0``. ``guidance_scale`` is
    passed as the per-sample ``guidance`` input unless it is None.
    """
    # The DataLoader is not prepared by accelerate, so its tensors start on the CPU.
    pixel_values = batch["pixel_values"].to(device=device, dtype=torch.bfloat16, non_blocking=True)
    input_ids_1 = batch["input_ids_1"].to(device, non_blocking=True)
    input_ids_2 = batch["input_ids_2"].to(device, non_blocking=True)

    with torch.no_grad():
        latents = vae.encode(pixel_values).latent_dist.sample()
        latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor
        # Encode prompts like FluxPipeline does at inference (no attention masks),
        # so the LoRA is trained on the same conditioning it will later receive.
        pooled_prompt_embeds = text_encoder(input_ids_1).pooler_output
        prompt_embeds = text_encoder_2(input_ids_2).last_hidden_state

    bsz, _, lat_h, lat_w = latents.shape
    model_input = _pack_latents(latents)
    embed_dtype = prompt_embeds.dtype
    img_ids = _latent_image_ids(lat_h // 2, lat_w // 2, device, embed_dtype)
    txt_ids = torch.zeros(prompt_embeds.shape[1], 3, device=device, dtype=embed_dtype)

    # Interpolate in float32: a bf16 t is coarsely quantized and would drift
    # from the timestep value the model is given.
    x0 = model_input.float()
    noise = torch.randn_like(x0)
    t = torch.rand(bsz, device=device)
    sigmas = t.view(-1, 1, 1)
    noisy_model_input = ((1.0 - sigmas) * x0 + sigmas * noise).to(model_input.dtype)

    guidance = None
    if guidance_scale is not None:
        guidance = torch.full((bsz,), guidance_scale, device=device, dtype=model_input.dtype)

    model_pred = transformer(
        hidden_states=noisy_model_input,
        # Pass t in [0, 1): diffusers' Flux transformer multiplies timestep (and
        # guidance) by 1000 itself; FluxPipeline likewise passes t / 1000.
        timestep=t,
        guidance=guidance,
        pooled_projections=pooled_prompt_embeds,
        encoder_hidden_states=prompt_embeds,
        txt_ids=txt_ids,
        img_ids=img_ids,
        return_dict=False,
    )[0]
    target = noise - x0
    return F.mse_loss(model_pred.float(), target, reduction="mean")


def _checkpoint_step(path):
    """Return N for a ``checkpoint-N`` directory, or None for anything else."""
    prefix, sep, step = path.name.partition("-")
    if path.is_dir() and prefix == "checkpoint" and sep and step.isdecimal():
        return int(step)
    return None


def _save_lora(accelerator, transformer, path, training_state):
    """Write ``lora_weights.pt``, then ``training_state.pt`` into *path*.

    The state file is written last, so a checkpoint directory without it is
    known to be incomplete.
    """
    from peft import get_peft_model_state_dict

    path.mkdir(parents=True, exist_ok=True)
    lora_state = get_peft_model_state_dict(accelerator.unwrap_model(transformer))
    torch.save(lora_state, path / "lora_weights.pt")
    torch.save(training_state, path / "training_state.pt")


def _prune_checkpoints(output_dir, keep=3):
    """Delete all but the *keep* highest-numbered ``checkpoint-N`` dirs (keep >= 1)."""
    import shutil

    ckpts = sorted(
        (p for p in output_dir.iterdir() if _checkpoint_step(p) is not None),
        key=_checkpoint_step,
    )
    for old_ckpt in ckpts[:max(len(ckpts) - keep, 0)]:
        shutil.rmtree(old_ckpt)
        print(f" Removed old: {old_ckpt.name}")


def main():
    """Fine-tune a LoRA adapter on the Flux transformer.

    Runs under accelerate (DDP, bf16 autocast, gradient accumulation) and
    resumes from the newest checkpoint in --output-dir. Every optimizer step
    (not every micro-batch) it may log, render samples, and save/upload a
    checkpoint, keeping the last 3 locally. When --max-train-steps steps are
    done it writes and uploads a ``final`` LoRA.
    """
    from collections import deque
    from torch.optim.lr_scheduler import LambdaLR

    args = get_args()
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation,
        mixed_precision="bf16",
        log_with=None,
    )
    # Seed after the Accelerator exists so device_specific can add the process
    # index: each rank then draws its own noise/timesteps instead of identical
    # ones. Rank-dependent LoRA init is harmless: DDP broadcasts rank 0's weights.
    set_seed(args.seed, device_specific=True)

    if accelerator.is_main_process:
        print(f" Devices: {accelerator.num_processes}")
        print(f" Batch size (per device): {args.batch_size}")
        print(f" Gradient accumulation: {args.gradient_accumulation}")
        print(f" Effective batch size: {args.batch_size * args.gradient_accumulation * accelerator.num_processes}")
        print(f" LoRA rank: {args.lora_rank}, alpha: {args.lora_alpha}")
        print(f" Max steps: {args.max_train_steps}")
        print(f" Save every: {args.save_steps} steps")

    # Train with the pipeline's own tokenizers/encoders/VAE. Loading separate
    # copies would keep a second, unused set alive inside `pipe`, and
    # generate_samples() moves the whole pipe onto the GPU.
    pipe = FluxPipeline.from_pretrained(args.model_name, torch_dtype=torch.bfloat16)
    tokenizer, tokenizer_2 = pipe.tokenizer, pipe.tokenizer_2
    text_encoder, text_encoder_2 = pipe.text_encoder, pipe.text_encoder_2
    vae = pipe.vae
    transformer = pipe.transformer
    for frozen in (vae, text_encoder, text_encoder_2):
        frozen.requires_grad_(False)

    # Guidance-distilled checkpoints (config.guidance_embeds, e.g. FLUX.1-dev)
    # need a `guidance` input; 3.5 matches the scale generate_samples() uses.
    guidance_scale = 3.5 if getattr(transformer.config, "guidance_embeds", False) else None

    lora_config = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        target_modules=["to_q", "to_k", "to_v", "to_out.0", "proj_in", "proj_out",
                        "ff.net.0.proj", "ff.net.2"],
        lora_dropout=0.0,
    )
    transformer = get_peft_model(transformer, lora_config)
    transformer.enable_gradient_checkpointing()

    if accelerator.is_main_process:
        trainable = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
        total = sum(p.numel() for p in transformer.parameters())
        print(f" Trainable params: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")

    optimizer = torch.optim.AdamW(
        [p for p in transformer.parameters() if p.requires_grad],
        lr=args.learning_rate,
        betas=(0.9, 0.999),
        weight_decay=0.01,
        eps=1e-8,
    )

    dataset = create_webdataset(args.data_dir, args.resolution, tokenizer, tokenizer_2)
    # The dataset yields 1-sample batches; the DataLoader groups --batch-size of them.
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        collate_fn=_concat_batches,
        num_workers=2,
        pin_memory=True,
        prefetch_factor=2,
    )

    # accelerate steps a prepared scheduler once per process for every optimizer
    # step (step_scheduler_with_optimizer=True). Scale the warmup by the process
    # count so --lr-warmup-steps counts optimizer steps.
    warmup_steps = args.lr_warmup_steps * accelerator.num_processes

    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        return 1.0

    lr_scheduler = LambdaLR(optimizer, lr_lambda)

    # The dataloader is deliberately not prepared: WebDataset's split_by_node already shards per rank.
    transformer, optimizer, lr_scheduler = accelerator.prepare(
        transformer, optimizer, lr_scheduler
    )

    device = accelerator.device
    vae.to(device, dtype=torch.bfloat16)
    text_encoder.to(device)
    text_encoder_2.to(device)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    resume_ckpt, global_step = find_latest_checkpoint(args.output_dir)
    if resume_ckpt is not None:
        if accelerator.is_main_process:
            print(f" Resuming from {resume_ckpt.name} (step {global_step})")
        from peft import set_peft_model_state_dict

        state = torch.load(resume_ckpt / "training_state.pt", map_location="cpu")
        optimizer.load_state_dict(state["optimizer"])
        lr_scheduler.load_state_dict(state["lr_scheduler"])
        lora_state = torch.load(resume_ckpt / "lora_weights.pt", map_location="cpu")
        set_peft_model_state_dict(accelerator.unwrap_model(transformer), lora_state)
    elif accelerator.is_main_process:
        print(" Starting from scratch")

    if accelerator.is_main_process:
        print(f"\n Training started at step {global_step}...")
    transformer.train()
    step_times = deque(maxlen=50)  # seconds per optimizer step, for the ETA
    batches = _infinite_batches(dataloader, accelerator.process_index)
    step_start = time.time()
    while global_step < args.max_train_steps:
        batch = next(batches)
        with accelerator.accumulate(transformer):
            loss = _compute_loss(
                transformer, vae, text_encoder, text_encoder_2, batch, device, guidance_scale
            )
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        # Everything below runs once per optimizer step. Checking it on every
        # micro-batch would repeat logging/saving/sampling for each accumulation
        # step, and would also fire at step 0 and right after a resume.
        if not accelerator.sync_gradients:
            continue

        global_step += 1
        step_times.append(time.time() - step_start)

        if global_step % 50 == 0 and accelerator.is_main_process:
            avg_time = sum(step_times) / len(step_times)
            steps_remaining = args.max_train_steps - global_step
            eta_hours = (steps_remaining * avg_time) / 3600
            print(
                f" Step {global_step}/{args.max_train_steps} | "
                f"Loss: {loss.item():.4f} | "
                f"LR: {lr_scheduler.get_last_lr()[0]:.2e} | "
                f"Time/step: {avg_time:.2f}s | "
                f"ETA: {eta_hours:.1f}h"
            )

        # Samples are rendered before the checkpoint so the upload includes them.
        if args.sample_steps > 0 and global_step % args.sample_steps == 0:
            generate_samples(accelerator, pipe, args.output_dir, global_step)
            accelerator.wait_for_everyone()

        if args.save_steps > 0 and global_step % args.save_steps == 0:
            if accelerator.is_main_process:
                ckpt_name = f"checkpoint-{global_step}"
                _save_lora(accelerator, transformer, output_dir / ckpt_name, {
                    "global_step": global_step,
                    "optimizer": optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                })
                print(f" Checkpoint saved: {ckpt_name}")
                upload_checkpoint(args.output_dir, ckpt_name, args.hf_user, args.hf_repo)
                _prune_checkpoints(output_dir, keep=3)
            accelerator.wait_for_everyone()

        # Restart the timer here so save/sample time is not counted as step time.
        step_start = time.time()

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        _save_lora(accelerator, transformer, output_dir / "final", {"global_step": global_step})
        print(f"\n Training complete! Final model saved at step {global_step}")
        upload_checkpoint(args.output_dir, "final", args.hf_user, args.hf_repo)
    accelerator.end_training()
if __name__ == "__main__":
    # Start training only when this file is executed as a script; importing
    # the module (e.g. to reuse its helpers) has no side effects.
    main()
|