File size: 22,909 Bytes
b373569 85eb3ff 5a193ef b373569 d79aee0 5a193ef 85eb3ff b373569 85eb3ff b373569 5a193ef b373569 5f4163e d79aee0 b373569 85eb3ff b373569 85eb3ff b373569 bdb80f0 b373569 85eb3ff 5f4163e b373569 bdb80f0 b373569 5a193ef d79aee0 5a193ef 85eb3ff 5a193ef faacfe9 5a193ef b476a30 5a193ef faacfe9 5a193ef b476a30 5a193ef b476a30 5a193ef b373569 85eb3ff faacfe9 85eb3ff b373569 d79aee0 b373569 5a193ef 85eb3ff 5a193ef b373569 5a193ef b373569 d79aee0 85eb3ff 5a193ef b373569 d79aee0 b373569 d79aee0 b373569 85eb3ff d79aee0 85eb3ff b373569 85eb3ff d79aee0 b373569 d79aee0 85eb3ff d79aee0 b373569 d79aee0 85eb3ff d79aee0 b373569 d79aee0 85eb3ff d79aee0 b373569 d79aee0 5a193ef d79aee0 85eb3ff d79aee0 85eb3ff d79aee0 5a193ef b373569 5a193ef d79aee0 b373569 85eb3ff d79aee0 b373569 d79aee0 b373569 85eb3ff 5a193ef b373569 69bd807 85eb3ff bdb80f0 b373569 5a193ef b373569 5a193ef 85eb3ff d79aee0 5a193ef 11b11fa 5a193ef 85eb3ff b373569 5a193ef 85eb3ff 5a193ef 85eb3ff 5a193ef 5f4163e 11b11fa d79aee0 b373569 d79aee0 b373569 5a193ef b373569 5a193ef b373569 5a193ef b373569 d79aee0 5a193ef b373569 5a193ef d79aee0 b373569 d79aee0 b373569 5a193ef d79aee0 5a193ef b373569 d79aee0 5a193ef d79aee0 b373569 5a193ef b373569 5f4163e 5a193ef 5f4163e 5a193ef 5f4163e 5a193ef d79aee0 5a193ef 85eb3ff d79aee0 5a193ef 32ae43b 5a193ef 32ae43b 5a193ef 32ae43b 5a193ef 32ae43b 11b11fa d79aee0 11b11fa d79aee0 5a193ef 11b11fa 5a193ef 85eb3ff 11b11fa 5a193ef 11b11fa 5a193ef 11b11fa 5a193ef 11b11fa 5a193ef 11b11fa 5a193ef 11b11fa 5a193ef 11b11fa d79aee0 5a193ef 85eb3ff d79aee0 85eb3ff d79aee0 85eb3ff b373569 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 
133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 | """
Fine-tune Flux with LoRA - 2 GPU split (encode on GPU0, train on GPU1).
Reads webdataset shards directly. Supports resume from checkpoint.
Follows diffusers reference implementation for correct flow matching.
"""
import argparse
import gc
import io
import math
import time
from pathlib import Path
import torch
import torch.nn.functional as F
import webdataset as wds
from PIL import Image
from torchvision import transforms
from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
def get_train_transforms(resolution=1024):
    """Build the image preprocessing pipeline used for training samples.

    The steps are: resize the shorter edge to ``resolution`` (Lanczos), center-crop
    to a ``resolution`` x ``resolution`` square, convert to a tensor, and normalize
    each channel with mean 0.5 / std 0.5 (mapping ToTensor's [0, 1] range to [-1, 1]).

    Args:
        resolution: target side length in pixels.

    Returns:
        A ``torchvision.transforms.Compose`` applying the steps above in order.
    """
    steps = [
        transforms.Resize(resolution, interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
    return transforms.Compose(steps)
def collate_batch(samples):
    """Merge a list of preprocessed samples into one training batch.

    Args:
        samples: sequence of dicts, each with an ``"image"`` tensor (all of the
            same shape) and a ``"caption"`` string.

    Returns:
        dict with ``"image"``: the image tensors stacked along a new leading
        dimension, and ``"caption"``: the captions as a list in sample order.
    """
    stacked_images = torch.stack([item["image"] for item in samples])
    caption_list = [item["caption"] for item in samples]
    return {"image": stacked_images, "caption": caption_list}
def create_webdataset(data_dir, resolution=1024, batch_size=1):
    """Build a shuffled, batched WebDataset pipeline over every ``*.tar`` shard in *data_dir*.

    Each sample's ``"jpg"`` entry is decoded, converted to RGB and passed through
    ``get_train_transforms(resolution)``; its optional ``"txt"`` entry becomes the
    caption (empty string when absent). Samples that fail anywhere in the
    pipeline -- reading the shard, decoding the image, or preprocessing -- are
    skipped rather than aborting training.

    Args:
        data_dir: directory containing the ``.tar`` shards.
        resolution: square output size for the images.
        batch_size: samples per batch; batches are assembled by ``collate_batch``.

    Returns:
        ``(dataset, num_shards)``.

    Raises:
        ValueError: if *data_dir* contains no ``.tar`` files.
    """
    transform = get_train_transforms(resolution)

    def preprocess(sample):
        # Best-effort: any failure yields None, which the .select() below drops.
        try:
            image = sample["jpg"]
            if isinstance(image, bytes):
                image = Image.open(io.BytesIO(image))
            # Normalise to 3 channels: grayscale/CMYK inputs would otherwise
            # break batching (torch.stack) and the 3-channel VAE.
            if image.mode != "RGB":
                image = image.convert("RGB")
            caption = sample.get("txt", b"")
            if isinstance(caption, bytes):
                caption = caption.decode("utf-8")
            return {"image": transform(image), "caption": caption}
        except Exception:
            return None

    tar_files = sorted(Path(data_dir).glob("*.tar"))
    if not tar_files:
        raise ValueError(f"No tar files found in {data_dir}")
    print(f" Found {len(tar_files)} shards")

    # The preprocess try/except only protects the .map() stage. Shard reading
    # and .decode() use webdataset's default handler, which re-raises. So a
    # corrupt JPEG or tar member would crash the DataLoader worker. Use
    # warn_and_continue there too, so bad samples are skipped consistently.
    dataset = (
        wds.WebDataset(
            [str(f) for f in tar_files],
            shardshuffle=True,
            empty_check=False,
            handler=wds.warn_and_continue,
        )
        .shuffle(1000)
        .decode("pil", handler=wds.warn_and_continue)
        .map(preprocess)
        .select(lambda x: x is not None)
        .batched(batch_size, collation_fn=collate_batch)
    )
    return dataset, len(tar_files)
def pack_latents(latents, batch_size, num_channels, height, width):
    """Pack a (B, C, H, W) latent grid into a sequence of 2x2 patch tokens.

    Every non-overlapping 2x2 spatial patch becomes one token. Within a token,
    features are ordered (channel, row-in-patch, col-in-patch), and tokens are
    ordered row-major over the (H/2, W/2) patch grid.

    Args:
        latents: tensor holding ``batch_size * num_channels * height * width``
            elements laid out as (B, C, H, W); it need not be contiguous.
        batch_size, num_channels, height, width: the dimensions of *latents*.

    Returns:
        Tensor of shape ``(batch_size, (height // 2) * (width // 2), num_channels * 4)``.

    Raises:
        ValueError: if *height* or *width* is odd (no exact 2x2 tiling).
    """
    if height % 2 or width % 2:
        raise ValueError(
            f"height and width must be even to form 2x2 patches, got {height}x{width}"
        )
    # reshape (not view) so non-contiguous inputs (e.g. permuted/sliced) work too.
    patches = latents.reshape(batch_size, num_channels, height // 2, 2, width // 2, 2)
    # -> (B, H/2, W/2, C, 2, 2): patch grid first, patch contents last.
    patches = patches.permute(0, 2, 4, 1, 3, 5)
    return patches.reshape(batch_size, (height // 2) * (width // 2), num_channels * 4)
def unpack_latents(latents, height, width, num_channels):
    """Invert ``pack_latents``: turn 2x2 patch tokens back into a (B, C, H, W) grid.

    Args:
        latents: packed tensor of shape (B, (height // 2) * (width // 2), num_channels * 4).
        height, width: spatial size of the unpacked latent grid.
        num_channels: channel count of the unpacked grid.

    Returns:
        Tensor of shape (B, num_channels, height, width).
    """
    n = latents.shape[0]
    half_h, half_w = height // 2, width // 2
    grid = latents.reshape(n, half_h, half_w, num_channels, 2, 2)
    # (B, H/2, W/2, C, 2, 2) -> (B, C, H/2, 2, W/2, 2) so rows/cols interleave correctly.
    grid = grid.permute(0, 3, 1, 4, 2, 5)
    return grid.reshape(n, num_channels, height, width)
def prepare_latent_image_ids(height, width, device, dtype):
    """Build per-token (0, row, col) position ids for a height x width token grid.

    Args:
        height, width: size of the packed token grid.
        device, dtype: placement and dtype of the returned tensor.

    Returns:
        Tensor of shape (height * width, 3), in row-major token order. Column 0
        is zero, column 1 is the row index and column 2 is the column index.
    """
    ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
    rows = torch.arange(height, device=device, dtype=dtype)
    cols = torch.arange(width, device=device, dtype=dtype)
    ids[..., 1] += rows.unsqueeze(1)
    ids[..., 2] += cols.unsqueeze(0)
    return ids.reshape(height * width, 3)
def compute_density_for_timestep_sampling(
    weighting_scheme, batch_size, logit_mean=0.0, logit_std=1.0, mode_scale=0.2
):
    """Draw per-example values ``u`` that the training loop uses as flow-matching sigmas.

    Args:
        weighting_scheme: ``"logit_normal"`` returns ``sigmoid(N(logit_mean, logit_std))``.
            ``"mode"`` warps a uniform draw as
            ``1 - u - mode_scale * (cos(pi*u/2)**2 - 1 + u)``.
            Any other value returns a plain uniform draw from [0, 1).
        batch_size: number of values to draw.
        logit_mean: mean of the normal distribution for ``"logit_normal"``.
        logit_std: standard deviation of that normal distribution.
        mode_scale: strength of the ``"mode"`` warp. The default of 0.2 preserves
            this script's original behaviour; the diffusers training scripts
            default to 1.29.

    Returns:
        1-D float tensor of length *batch_size*, drawn with the global torch RNG.
    """
    if weighting_scheme == "logit_normal":
        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,))
        return torch.sigmoid(u)
    u = torch.rand(batch_size)
    if weighting_scheme == "mode":
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    return u
def compute_loss_weighting(weighting_scheme, sigmas):
    """Return per-example loss weights for the given sigmas.

    Args:
        weighting_scheme: ``"sigma_sqrt"`` gives ``sigmas**-2`` capped at 10.
            ``"cosmap"`` gives ``2 / (pi * (1 - 2*sigma + 2*sigma**2))``.
            Any other value gives uniform weights of 1.
        sigmas: tensor of noise levels.

    Returns:
        Tensor with the same shape as *sigmas*.
    """
    if weighting_scheme == "sigma_sqrt":
        # Cap the weight so sigmas near 0 (inverse-square blow-up) stay bounded.
        return torch.clamp(sigmas ** -2.0, max=10.0)
    if weighting_scheme == "cosmap":
        denominator = 1 - 2 * sigmas + 2 * sigmas ** 2
        return 2.0 / (math.pi * denominator)
    return torch.ones_like(sigmas)
def find_latest_checkpoint(output_dir):
    """Find the highest-numbered ``checkpoint-<step>`` directory inside *output_dir*.

    Only directories named exactly ``checkpoint-`` followed by decimal digits
    count. Other entries, such as files or ``checkpoint-final``, are ignored.
    The previous implementation could crash parsing such names with ``int()``.

    Args:
        output_dir: directory to scan (str or Path).

    Returns:
        ``(path, step)`` for the newest checkpoint. Returns ``(None, 0)`` when
        *output_dir* is not an existing directory or holds no valid checkpoint.
    """
    output_dir = Path(output_dir)
    if not output_dir.is_dir():
        return None, 0
    candidates = []
    for entry in output_dir.iterdir():
        prefix, _, suffix = entry.name.partition("-")
        if prefix == "checkpoint" and suffix.isdecimal() and entry.is_dir():
            candidates.append((int(suffix), entry))
    if not candidates:
        return None, 0
    step, path = max(candidates, key=lambda c: c[0])
    return path, step
@torch.no_grad()
def generate_samples(
    transformer, vae, text_encoder, text_encoder_2,
    tokenizer, tokenizer_2,
    prompts, output_dir, global_step,
    encode_device, train_device,
    num_inference_steps=28, guidance_scale=3.5,
    model_name="black-forest-labs/FLUX.1-dev", cache_dir=None, seed=0,
    height=512, width=512,
):
    """Render one preview image per prompt with the current (LoRA) transformer.

    The VAE and both text encoders are temporarily moved to *train_device*. A
    ``FluxPipeline`` is assembled from the passed components (only the remaining
    parts, such as the scheduler, are loaded from *model_name*). Images are saved as
    ``<output_dir>/samples/step_<global_step>_sample_<i>.png``. Failures are reported
    as a warning rather than raised, so training continues. The components are
    always moved back to *encode_device*, and the transformer is returned to train mode.

    Args:
        transformer: the (possibly peft-wrapped) Flux transformer being trained.
        vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2: pipeline components.
        prompts: list of prompt strings.
        output_dir: run directory; images go into its ``samples`` subfolder.
        global_step: training step, used in the file names.
        encode_device: device the VAE/text encoders live on during training.
        train_device: device used for generation.
        num_inference_steps, guidance_scale: forwarded to the pipeline call.
        model_name: repo/path used to load the pipeline's remaining components.
        cache_dir: forwarded to ``FluxPipeline.from_pretrained``.
        seed: base seed; prompt ``i`` uses ``seed + i``. This keeps previews
            comparable across steps and leaves the global training RNG untouched.
        height, width: output image size.
    """
    from diffusers import FluxPipeline
    from peft import PeftModel

    output_dir = Path(output_dir) / "samples"
    output_dir.mkdir(parents=True, exist_ok=True)

    # peft injects LoRA layers into the wrapped model in place, so the base model
    # already includes the adapters. Passing it (rather than the PeftModel wrapper)
    # avoids diffusers' component type check rejecting a non-ModelMixin object.
    pipe_transformer = transformer.get_base_model() if isinstance(transformer, PeftModel) else transformer

    transformer.eval()
    gen_device = train_device
    pipe = None
    try:
        # Move all components to the same device for inference.
        vae.to(gen_device)
        text_encoder.to(gen_device)
        text_encoder_2.to(gen_device)
        pipe = FluxPipeline.from_pretrained(
            model_name,
            transformer=pipe_transformer,
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            torch_dtype=torch.bfloat16,
            cache_dir=cache_dir,
        )
        pipe = pipe.to(gen_device)
        for i, prompt in enumerate(prompts):
            generator = torch.Generator(device="cpu").manual_seed(seed + i)
            image = pipe(
                prompt=prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                height=height,
                width=width,
                generator=generator,
            ).images[0]
            image.save(output_dir / f"step_{global_step:06d}_sample_{i}.png")
    except Exception as e:
        print(f" WARNING: Sample generation failed: {e}")
    finally:
        del pipe
        # Move components back to encode_device for training.
        vae.to(encode_device)
        text_encoder.to(encode_device)
        text_encoder_2.to(encode_device)
        transformer.train()
        torch.cuda.empty_cache()
def main():
    """Parse CLI flags, build models and data, and run LoRA flow-matching training.

    The VAE and both text encoders run frozen on ``--encode-device``; the
    LoRA-wrapped Flux transformer is trained on ``--train-device``. Every
    ``--save-steps`` optimizer steps, the adapter weights and the optimizer and
    LR-scheduler state are written to ``<output-dir>/checkpoint-<step>``. Only
    the newest three checkpoints are kept. Every ``--sample-steps`` steps,
    preview images are rendered.

    ``--resume-from-checkpoint`` accepts ``auto`` (newest checkpoint in
    ``--output-dir``), ``none`` (start fresh), or an explicit checkpoint directory.
    """
    import shutil

    def checkpoint_step(path):
        # Return N for a directory named exactly ``checkpoint-N``; None otherwise.
        prefix, _, suffix = path.name.partition("-")
        if prefix == "checkpoint" and suffix.isdecimal() and path.is_dir():
            return int(suffix)
        return None

    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", default="black-forest-labs/FLUX.1-dev")
    parser.add_argument("--data-dir", type=Path, required=True)
    parser.add_argument("--output-dir", type=Path, required=True)
    parser.add_argument("--cache-dir", default="/data0/models")
    parser.add_argument("--resolution", type=int, default=1024)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--gradient-accumulation", type=int, default=8)
    parser.add_argument("--learning-rate", type=float, default=1e-4)
    parser.add_argument("--lr-scheduler", default="constant")
    parser.add_argument("--lr-warmup-steps", type=int, default=100)
    parser.add_argument("--max-train-steps", type=int, default=999999999)
    parser.add_argument("--save-steps", type=int, default=2000)
    parser.add_argument("--sample-steps", type=int, default=2000)
    parser.add_argument("--lora-rank", type=int, default=128)
    parser.add_argument("--lora-alpha", type=int, default=64)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--encode-device", default="cuda:0")
    parser.add_argument("--train-device", default="cuda:1")
    parser.add_argument(
        "--resume-from-checkpoint", default="auto",
        help='"auto" = newest checkpoint in --output-dir, "none" = start fresh, '
             "otherwise a checkpoint directory",
    )
    parser.add_argument("--guidance-scale", type=float, default=1.0)
    parser.add_argument("--weighting-scheme", default="none", choices=["none", "logit_normal", "mode", "sigma_sqrt", "cosmap"])
    parser.add_argument("--logit-mean", type=float, default=0.0)
    parser.add_argument("--logit-std", type=float, default=1.0)
    parser.add_argument("--max-grad-norm", type=float, default=1.0)
    args = parser.parse_args()
    # These are used as modulo divisors in the loop; 0 would crash mid-run.
    for flag in ("gradient_accumulation", "save_steps", "sample_steps"):
        if getattr(args, flag) < 1:
            parser.error(f"--{flag.replace('_', '-')} must be >= 1")
    args.output_dir.mkdir(parents=True, exist_ok=True)
    torch.manual_seed(args.seed)

    encode_device = torch.device(args.encode_device)
    train_device = torch.device(args.train_device)
    if torch.cuda.device_count() < 2:
        print(" Only 1 GPU, using same device for encode + train")
        encode_device = torch.device("cuda:0")
        train_device = torch.device("cuda:0")

    # Resume logic. Previously, any value other than "auto" (e.g. an explicit
    # checkpoint path) was silently ignored.
    resume_path, resume_step = None, 0
    resume_arg = (args.resume_from_checkpoint or "").strip()
    if resume_arg == "auto":
        resume_path, resume_step = find_latest_checkpoint(args.output_dir)
    elif resume_arg.lower() not in ("", "none"):
        resume_path = Path(resume_arg)
        if not resume_path.is_dir():
            raise FileNotFoundError(f"Checkpoint directory not found: {resume_path}")
        resume_step = checkpoint_step(resume_path) or 0
    if resume_path is not None:
        print(f" Resuming from {resume_path} (step {resume_step})")

    # Load tokenizers
    print(" Loading tokenizers...")
    from transformers import CLIPTokenizer, T5TokenizerFast
    tokenizer = CLIPTokenizer.from_pretrained(args.model_name, subfolder="tokenizer", cache_dir=args.cache_dir)
    tokenizer_2 = T5TokenizerFast.from_pretrained(args.model_name, subfolder="tokenizer_2", cache_dir=args.cache_dir)

    # Load VAE + text encoders on encode_device
    print(f" Loading VAE + text encoders on {encode_device}...")
    from diffusers import AutoencoderKL
    from transformers import CLIPTextModel, T5EncoderModel
    vae = AutoencoderKL.from_pretrained(
        args.model_name, subfolder="vae", torch_dtype=torch.bfloat16, cache_dir=args.cache_dir
    ).to(encode_device).eval()
    vae.requires_grad_(False)
    text_encoder = CLIPTextModel.from_pretrained(
        args.model_name, subfolder="text_encoder", torch_dtype=torch.bfloat16, cache_dir=args.cache_dir
    ).to(encode_device).eval()
    text_encoder.requires_grad_(False)
    text_encoder_2 = T5EncoderModel.from_pretrained(
        args.model_name, subfolder="text_encoder_2", torch_dtype=torch.bfloat16, cache_dir=args.cache_dir
    ).to(encode_device).eval()
    text_encoder_2.requires_grad_(False)
    vae_shift = vae.config.shift_factor
    vae_scale = vae.config.scaling_factor
    print(f" VAE config: shift_factor={vae_shift}, scaling_factor={vae_scale}")

    # Load transformer on train_device
    print(f" Loading Flux transformer on {train_device}...")
    from diffusers import FluxTransformer2DModel
    transformer = FluxTransformer2DModel.from_pretrained(
        args.model_name, subfolder="transformer", torch_dtype=torch.bfloat16, cache_dir=args.cache_dir
    )
    # Check guidance
    has_guidance = getattr(transformer.config, "guidance_embeds", False)
    print(f" Model has guidance_embeds: {has_guidance}")

    # LoRA - comprehensive target modules for Flux MMDiT
    lora_target_modules = [
        "attn.to_q", "attn.to_k", "attn.to_v", "attn.to_out.0",
        "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
        "ff.net.0.proj", "ff.net.2",
        "ff_context.net.0.proj", "ff_context.net.2",
    ]
    lora_config = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=0.0,
    )
    transformer = get_peft_model(transformer, lora_config)

    # Load checkpoint weights if resuming
    if resume_path is not None:
        adapter_path = resume_path / "adapter_model.safetensors"
        adapter_bin = resume_path / "adapter_model.bin"
        state_dict = None
        if adapter_path.exists():
            import safetensors.torch
            state_dict = safetensors.torch.load_file(str(adapter_path))
        elif adapter_bin.exists():
            state_dict = torch.load(str(adapter_bin), map_location="cpu")
        else:
            print(f" WARNING: no adapter weights found in {resume_path}; LoRA starts from scratch")
        if state_dict is not None:
            set_peft_model_state_dict(transformer, state_dict)
            print(f" Loaded LoRA weights from checkpoint")
    transformer.to(train_device)
    transformer.print_trainable_parameters()
    transformer.train()

    # Optimizer + scheduler
    trainable_params = [p for p in transformer.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=args.learning_rate, weight_decay=0.01, betas=(0.9, 0.999))
    from diffusers.optimization import get_scheduler
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # Restore optimizer + scheduler state if resuming
    if resume_path is not None:
        training_state_path = resume_path / "training_state.pt"
        if training_state_path.exists():
            state = torch.load(str(training_state_path), map_location="cpu")
            optimizer.load_state_dict(state["optimizer"])
            lr_scheduler.load_state_dict(state["lr_scheduler"])
            # The saved step is authoritative (explicit paths may not encode it).
            resume_step = int(state.get("global_step", resume_step))
            print(f" Restored optimizer + scheduler state from checkpoint")
        else:
            print(f" No training_state.pt found, fast-forwarding scheduler...")
            for _ in range(resume_step):
                lr_scheduler.step()

    # Dataset
    print(f" Loading dataset from {args.data_dir}")
    train_dataset, num_shards = create_webdataset(args.data_dir, args.resolution, args.batch_size)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=None, num_workers=2, prefetch_factor=4
    )

    # Sample prompts for monitoring
    sample_prompts = [
        "a beautiful mountain landscape at sunset, 4k, highly detailed",
        "a cute cat sitting on a windowsill, natural lighting",
        "a futuristic city skyline at night with neon lights",
        "portrait of a woman with flowers in her hair, oil painting style",
    ]
    # Preview settings must match the model family. FLUX.1-dev (guidance-distilled,
    # has guidance_embeds) needs ~28 steps at guidance ~3.5. 4 steps / guidance 0
    # are FLUX.1-schnell settings and produce noise on dev.
    if has_guidance:
        preview_steps, preview_guidance = 28, 3.5
    else:
        preview_steps, preview_guidance = 4, 0.0

    # Training loop
    global_step = resume_step
    accum_loss = 0.0
    loss_count = 0        # micro-batches included in accum_loss since the last log
    accum_grad_norm = 0.0
    grad_count = 0        # optimizer steps included in accum_grad_norm since the last log
    accum_count = 0       # valid micro-batches processed (drives grad accumulation)
    log_interval = 50
    t0 = time.time()
    print(f"\n === Training Config ===")
    print(f" Model: {args.model_name}")
    print(f" LoRA rank: {args.lora_rank}, alpha: {args.lora_alpha}, scaling: {args.lora_alpha/args.lora_rank:.2f}")
    print(f" Batch size: {args.batch_size}, Grad accum: {args.gradient_accumulation}")
    print(f" Effective batch: {args.batch_size * args.gradient_accumulation}")
    print(f" LR: {args.learning_rate}, Scheduler: {args.lr_scheduler}, Warmup: {args.lr_warmup_steps}")
    print(f" Weighting: {args.weighting_scheme}")
    print(f" Guidance: {args.guidance_scale if has_guidance else 'N/A (Schnell)'}")
    print(f" Encode: {encode_device}, Train: {train_device}")
    print(f" Save every {args.save_steps} steps, Sample every {args.sample_steps} steps")
    print(f" Starting from step {global_step}")
    print(f" ========================\n")

    optimizer.zero_grad()
    while global_step < args.max_train_steps:
        batches_seen = 0
        for batch in train_dataloader:
            if global_step >= args.max_train_steps:
                break
            batches_seen += 1
            images = batch["image"].to(encode_device, dtype=torch.bfloat16)
            captions = batch["caption"]
            bs = images.shape[0]

            # === Encode on encode_device ===
            with torch.no_grad():
                # VAE encode
                latents = vae.encode(images).latent_dist.sample()
                latents = (latents - vae_shift) * vae_scale
                _, num_channels, latent_h, latent_w = latents.shape
                # Text encode - CLIP (pooled)
                text_ids = tokenizer(
                    captions, padding="max_length", max_length=77,
                    truncation=True, return_tensors="pt"
                ).input_ids.to(encode_device)
                pooled_prompt_embeds = text_encoder(text_ids, output_hidden_states=False).pooler_output
                # Text encode - T5 (sequence)
                text_ids_2 = tokenizer_2(
                    captions, padding="max_length", max_length=512,
                    truncation=True, return_tensors="pt"
                ).input_ids.to(encode_device)
                encoder_hidden_states = text_encoder_2(text_ids_2)[0]

            # === Move to train device ===
            latents = latents.to(train_device)
            pooled_prompt_embeds = pooled_prompt_embeds.to(train_device)
            encoder_hidden_states = encoder_hidden_states.to(train_device)

            # === Flow matching setup ===
            noise = torch.randn_like(latents)
            u = compute_density_for_timestep_sampling(
                args.weighting_scheme, bs, args.logit_mean, args.logit_std
            )
            # u in [0, 1] is used directly as sigma (linear schedule). Keep an fp32
            # copy for the loss weighting and a bf16 copy for the model inputs.
            sigmas_fp32 = u.to(device=train_device, dtype=torch.float32)
            sigmas = sigmas_fp32.to(torch.bfloat16)
            sigmas_expand = sigmas.view(-1, 1, 1, 1)
            # Noisy latents: linear interpolation; target velocity = noise - clean.
            noisy_latents = (1.0 - sigmas_expand) * latents + sigmas_expand * noise
            target = noise - latents

            # === Pack latents for transformer ===
            packed_noisy = pack_latents(noisy_latents, bs, num_channels, latent_h, latent_w)
            packed_target = pack_latents(target, bs, num_channels, latent_h, latent_w)

            # === Positional IDs (packed grid is latent_h//2 x latent_w//2) ===
            img_ids = prepare_latent_image_ids(
                latent_h // 2, latent_w // 2, train_device, torch.bfloat16
            )
            txt_ids = torch.zeros(encoder_hidden_states.shape[1], 3, device=train_device, dtype=torch.bfloat16)

            # === Guidance ===
            guidance = None
            if has_guidance:
                guidance = torch.full((bs,), args.guidance_scale, device=train_device, dtype=torch.bfloat16)

            # === Forward pass ===
            # The transformer takes timestep in [0, 1] (diffusers passes t / 1000).
            # Pass sigmas directly: the old bf16 round trip sigmas*1000/1000 could
            # make the timestep differ slightly from the sigma used for noising.
            with torch.amp.autocast("cuda", dtype=torch.bfloat16):
                model_pred = transformer(
                    hidden_states=packed_noisy,
                    timestep=sigmas,
                    guidance=guidance,
                    encoder_hidden_states=encoder_hidden_states,
                    pooled_projections=pooled_prompt_embeds,
                    img_ids=img_ids,
                    txt_ids=txt_ids,
                    return_dict=False,
                )[0]

            # === Loss computation in fp32 ===
            weighting = compute_loss_weighting(args.weighting_scheme, sigmas_fp32)
            weighting = weighting.view(-1, 1, 1).to(model_pred.device)
            per_sample = (
                (weighting * (model_pred.float() - packed_target.float()) ** 2)
                .reshape(bs, -1)
                .mean(dim=1)
            )
            loss = per_sample.mean()

            if not torch.isfinite(loss):
                print(f" WARNING: Invalid loss at step {global_step}, skipping batch", flush=True)
                # backward() was never called for this batch, so the gradients
                # already accumulated in this window are still valid. Keep them
                # (no zero_grad) and don't count this batch toward the window.
                continue

            (loss / args.gradient_accumulation).backward()
            accum_loss += loss.item()
            loss_count += 1
            accum_count += 1
            if accum_count % args.gradient_accumulation != 0:
                continue

            # === Optimizer step ===
            grad_norm = torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm)
            accum_grad_norm += grad_norm.item()
            grad_count += 1
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            global_step += 1

            # === Logging ===
            if global_step % log_interval == 0:
                elapsed = time.time() - t0
                steps_done = global_step - resume_step
                steps_per_sec = steps_done / elapsed if elapsed > 0 else 0
                avg_loss = accum_loss / max(loss_count, 1)
                avg_grad = accum_grad_norm / max(grad_count, 1)
                cur_lr = lr_scheduler.get_last_lr()[0]
                print(
                    f" Step {global_step:6d} | "
                    f"Loss: {avg_loss:.4f} | "
                    f"GradNorm: {avg_grad:.3f} | "
                    f"LR: {cur_lr:.2e} | "
                    f"Speed: {steps_per_sec:.2f} st/s | "
                    f"Elapsed: {elapsed/3600:.1f}h",
                    flush=True,
                )
                accum_loss, loss_count = 0.0, 0
                accum_grad_norm, grad_count = 0.0, 0

            # === Save checkpoint ===
            if global_step % args.save_steps == 0:
                save_path = args.output_dir / f"checkpoint-{global_step}"
                save_path.mkdir(parents=True, exist_ok=True)
                transformer.save_pretrained(save_path)
                # Save optimizer state for proper resume
                torch.save({
                    "optimizer": optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                    "global_step": global_step,
                }, save_path / "training_state.pt")
                print(f" Saved checkpoint: {save_path}", flush=True)
                # Cleanup old checkpoints (keep last 3). Only well-formed
                # checkpoint-N dirs are considered. The checkpoint just written
                # is never removed, even if older-numbered runs' dirs sort after it.
                all_ckpts = sorted(
                    (d for d in args.output_dir.iterdir() if checkpoint_step(d) is not None),
                    key=checkpoint_step,
                )
                for old_ckpt in all_ckpts[:-3]:
                    if old_ckpt == save_path:
                        continue
                    shutil.rmtree(old_ckpt)
                    print(f" Removed old checkpoint: {old_ckpt.name}")

            # === Generate samples ===
            if global_step % args.sample_steps == 0:
                print(f" Generating samples at step {global_step}...")
                generate_samples(
                    transformer=transformer,
                    vae=vae,
                    text_encoder=text_encoder,
                    text_encoder_2=text_encoder_2,
                    tokenizer=tokenizer,
                    tokenizer_2=tokenizer_2,
                    prompts=sample_prompts,
                    output_dir=args.output_dir,
                    global_step=global_step,
                    encode_device=encode_device,
                    train_device=train_device,
                    num_inference_steps=preview_steps,
                    guidance_scale=preview_guidance,
                )
        if batches_seen == 0:
            # Without this guard an empty/unreadable dataset spins forever.
            raise RuntimeError(f"Dataset in {args.data_dir} yielded no usable samples")

    # Final save
    final_path = args.output_dir / "final"
    final_path.mkdir(parents=True, exist_ok=True)
    transformer.save_pretrained(final_path)
    print(f" Training complete! Saved to {final_path}")
if __name__ == "__main__":
    # Start training only when run as a script, not when imported as a module.
    main()
|