| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| | from torch.utils.data import Dataset, DataLoader |
| | from pathlib import Path |
| | import argparse |
| | from tqdm import tqdm |
| | from safetensors.torch import save_file, load_file |
| | from collections import deque |
| | from model import LocalSongModel |
| |
|
# Tag id list attached to every training sample (single-concept LoRA training).
# NOTE(review): 1908 presumably indexes the target style/class in the model's
# tag embedding table (num_classes=2304 in main) — confirm against the vocab.
HARDCODED_TAGS = [1908]
# Allow faster float32 matmul kernels (e.g. TF32 on Ampere+) at slightly
# reduced precision.
torch.set_float32_matmul_precision('high')
| |
|
class LoRALinear(nn.Module):
    """Wraps a frozen ``nn.Linear`` with a trainable low-rank residual.

    The adapted layer computes ``base(x) + (x @ A @ B) * (alpha / rank)``.
    ``B`` starts at zero, so immediately after construction the wrapper is
    exactly equivalent to the base layer.
    """

    def __init__(self, original_linear: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.original_linear = original_linear
        self.rank = rank
        self.alpha = alpha
        # Effective scale applied to the low-rank update.
        self.scaling = alpha / rank

        # A: (in_features, rank), B: (rank, out_features).
        self.lora_A = nn.Parameter(torch.zeros(original_linear.in_features, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, original_linear.out_features))

        # Kaiming init for A, zeros for B => zero initial delta.
        nn.init.kaiming_uniform_(self.lora_A, a=5 ** 0.5)
        nn.init.zeros_(self.lora_B)

        # Freeze the wrapped layer; only A and B receive gradients.
        for frozen in self.original_linear.parameters():
            frozen.requires_grad = False

    def forward(self, x):
        delta = x.matmul(self.lora_A).matmul(self.lora_B)
        return self.original_linear(x) + self.scaling * delta
| |
|
def inject_lora(model: LocalSongModel, rank: int = 8, alpha: float = 16.0,
                target_modules=('qkv', 'proj', 'w1', 'w2', 'w3', 'q_proj', 'kv_proj'),
                device=None):
    """Replace matching ``nn.Linear`` modules with ``LoRALinear`` wrappers, in place.

    Args:
        model: Model whose Linear layers are to be adapted.
        rank: LoRA rank (inner dimension of the A/B factorization).
        alpha: LoRA scale numerator (effective scale is alpha / rank).
        target_modules: Substrings matched against each dotted module name.
            A tuple rather than a list so the default argument is immutable.
        device: Device for the new LoRA tensors; defaults to the model's device.

    Returns:
        The same ``model`` instance, mutated in place.
    """
    if device is None:
        device = next(model.parameters()).device

    # Snapshot matching layers BEFORE mutating the tree: replacing children
    # while iterating named_modules() only works by generator-timing luck, and
    # a second call would otherwise re-wrap the Linear stored inside an
    # existing LoRALinear (its name '*.original_linear' still matches).
    candidates = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
        and any(target in name for target in target_modules)
        and not name.endswith('original_linear')
    ]

    lora_modules = []
    for name, module in candidates:
        # Walk the dotted path to find the parent holding this Linear.
        *parent_path, attr_name = name.split('.')
        parent = model
        for p in parent_path:
            parent = getattr(parent, p)

        lora_layer = LoRALinear(module, rank=rank, alpha=alpha)

        # Only the adapter tensors are moved; the frozen base layer is assumed
        # to already live on `device` (main() calls model.to(device) first).
        lora_layer.lora_A.data = lora_layer.lora_A.data.to(device)
        lora_layer.lora_B.data = lora_layer.lora_B.data.to(device)
        setattr(parent, attr_name, lora_layer)
        lora_modules.append(name)

    print(f"Injected LoRA into {len(lora_modules)} layers:")
    for name in lora_modules[:5]:
        print(f" - {name}")
    if len(lora_modules) > 5:
        print(f" ... and {len(lora_modules) - 5} more")

    return model
| |
|
def get_lora_parameters(model):
    """Collect the trainable LoRA tensors (A then B per layer) for the optimizer."""
    return [
        tensor
        for module in model.modules()
        if isinstance(module, LoRALinear)
        for tensor in (module.lora_A, module.lora_B)
    ]
| |
|
def save_lora_weights(model, output_path):
    """Serialize all LoRA A/B tensors to a safetensors file.

    Keys are ``<module_name>.lora_A`` / ``<module_name>.lora_B``, mirroring the
    module names produced by inject_lora so the file can be re-applied later.
    """
    lora_state_dict = {}

    for name, module in model.named_modules():
        if isinstance(module, LoRALinear):
            # detach().cpu().contiguous(): safetensors requires contiguous
            # storage and plain (non-autograd) tensors; this also makes the
            # file device-independent when training ran on GPU.
            lora_state_dict[f"{name}.lora_A"] = module.lora_A.detach().cpu().contiguous()
            lora_state_dict[f"{name}.lora_B"] = module.lora_B.detach().cpu().contiguous()

    save_file(lora_state_dict, output_path)
    print(f"Saved {len(lora_state_dict)} LoRA parameters to {output_path}")
| |
|
class LatentDataset(Dataset):
    """Dataset of pre-encoded VAE latents stored as individual ``.pt`` files."""

    def __init__(self, latents_dir: str):
        """Index every ``*.pt`` file under ``latents_dir`` (sorted for determinism).

        Raises:
            ValueError: if the directory contains no ``.pt`` files.
        """
        self.latents_dir = Path(latents_dir)

        # glob returns a generator; sorted() already materializes it.
        self.latent_files = sorted(self.latents_dir.glob("*.pt"))

        if not self.latent_files:
            raise ValueError(f"No .pt files found in {latents_dir}")

        print(f"Found {len(self.latent_files)} latent files")

    def __len__(self) -> int:
        return len(self.latent_files)

    def __getitem__(self, idx):
        # map_location='cpu' keeps CPU DataLoader workers from trying to
        # re-materialize CUDA storage when latents were saved from GPU tensors.
        latent = torch.load(self.latent_files[idx], map_location="cpu")

        # Normalize to 4-D (1, C, H, W) so downstream code sees one layout.
        if latent.ndim == 3:
            latent = latent.unsqueeze(0)

        return latent
| |
|
class RectifiedFlow:
    """Rectified-flow objective: regress the straight-line velocity field."""

    def __init__(self, model):
        # Velocity network, called as model(z_t, t, cond).
        self.model = model

    def forward(self, x, cond):
        """Return the scalar MSE flow-matching loss for a clean batch ``x``."""
        batch = x.size(0)

        # Timesteps in (0, 1): sigmoid of a standard normal (logit-normal
        # schedule). RNG order matters for reproducibility: t first, noise second.
        t = torch.sigmoid(torch.randn((batch,), device=x.device))

        # Broadcast t over every non-batch dimension.
        t_b = t.view(batch, *([1] * (x.ndim - 1)))

        noise = torch.randn_like(x)
        # Linear interpolation between data (t=0) and pure noise (t=1).
        z_t = (1 - t_b) * x + t_b * noise

        predicted = self.model(z_t, t, cond)

        # d z_t / d t along the straight path.
        velocity = noise - x
        return ((predicted - velocity) ** 2).mean()
| |
|
def collate_fn(batch, subsection_length=1024, tags=None):
    """Collate latents into a batch, cropping or padding each to a fixed width.

    Args:
        batch: list of latent tensors, 3-D (C, H, W) or 4-D (1, C, H, W).
        subsection_length: target width. Wider latents get a random crop;
            narrower ones are right-padded with zeros.
        tags: optional tag-id list attached to every sample; defaults to the
            module-level HARDCODED_TAGS (evaluated lazily, so callers that
            pass their own tags never touch the constant).

    Returns:
        (latents, tags): a (B, C, H, subsection_length) tensor and a list of
        B tag lists.
    """
    sampled_latents = []

    for latent in batch:
        if latent.ndim == 3:
            latent = latent.unsqueeze(0)

        _, _, _, width = latent.shape

        if width < subsection_length:
            # Too short: zero-pad on the right up to the target width.
            pad_amount = subsection_length - width
            latent = torch.nn.functional.pad(latent, (0, pad_amount), mode='constant', value=0)
        else:
            # Long enough: sample a random fixed-width window.
            max_start = width - subsection_length
            start_idx = torch.randint(0, max_start + 1, (1,)).item()
            latent = latent[:, :, :, start_idx:start_idx + subsection_length]

        sampled_latents.append(latent.squeeze(0))

    batch_latents = torch.stack(sampled_latents)

    # One shared tag list per sample, matching the original behavior.
    tag_list = HARDCODED_TAGS if tags is None else tags
    batch_tags = [tag_list] * len(batch)

    return batch_latents, batch_tags
| |
|
def main():
    """CLI entry point: LoRA fine-tune the base LocalSong model on pre-encoded latents."""
    parser = argparse.ArgumentParser(description='LoRA training for LocalSong model with embedding training')

    # Data / checkpoint paths.
    parser.add_argument('--latents_dir', type=str, required=True,
                        help='Directory containing VAE-encoded latents (.pt files)')
    parser.add_argument('--checkpoint', type=str, default='checkpoints/checkpoint_461260.safetensors',
                        help='Path to base model checkpoint')
    # LoRA hyperparameters.
    parser.add_argument('--lora_rank', type=int, default=16,
                        help='LoRA rank')
    parser.add_argument('--lora_alpha', type=float, default=16,
                        help='LoRA alpha (scaling factor)')
    # Optimization hyperparameters.
    parser.add_argument('--batch_size', type=int, default=16,
                        help='Batch size')
    parser.add_argument('--lr', type=float, default=2e-4,
                        help='Learning rate')
    parser.add_argument('--steps', type=int, default=1500,
                        help='Number of training steps')
    parser.add_argument('--subsection_length', type=int, default=512,
                        help='Latent subsection length')
    # Output / checkpointing.
    parser.add_argument('--output', type=str, default='lora.safetensors',
                        help='Output path for LoRA weights')
    parser.add_argument('--save_every', type=int, default=500,
                        help='Save checkpoint every N steps')

    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    print(f"Using hardcoded tags: {HARDCODED_TAGS}")

    # NOTE(review): architecture hyperparameters are hard-coded; they must
    # match the checkpoint being loaded (strict=True below will fail otherwise).
    print(f"Loading base model from {args.checkpoint}")
    model = LocalSongModel(
        in_channels=8,
        num_groups=16,
        hidden_size=1024,
        decoder_hidden_size=2048,
        num_blocks=36,
        patch_size=(16, 1),
        num_classes=2304,
        max_tags=8,
    )

    print(f"Loading checkpoint from {args.checkpoint}")
    state_dict = load_file(args.checkpoint)
    model.load_state_dict(state_dict, strict=True)
    print("Base model loaded")

    # Move to device FIRST, then inject: inject_lora reads the model's device
    # to place the new LoRA tensors.
    model = model.to(device)
    model = inject_lora(model, rank=args.lora_rank, alpha=args.lora_alpha, device=device)

    model.train()

    # Only the LoRA A/B tensors are optimized; base weights stay frozen.
    lora_params = get_lora_parameters(model)
    optimizer = optim.Adam(lora_params, lr=args.lr)
    print(f"Training {len(lora_params)} LoRA parameters")

    dataset = LatentDataset(args.latents_dir)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0,
        # Lambda closes over args; num_workers=0 so it need not be picklable.
        collate_fn=lambda batch: collate_fn(batch, args.subsection_length)
    )

    rf = RectifiedFlow(model)

    print("\nStarting training...")
    step = 0
    pbar = tqdm(total=args.steps, desc="Training")

    # Rolling window for a smoothed loss readout in the progress bar.
    loss_history = deque(maxlen=50)

    # Outer while re-starts the dataloader so training runs for exactly
    # args.steps optimizer steps regardless of dataset size.
    while step < args.steps:
        for batch_latents, batch_tags in dataloader:
            batch_latents = batch_latents.to(device)

            optimizer.zero_grad()
            loss = rf.forward(batch_latents, batch_tags)

            loss.backward()
            # Clip only the trainable LoRA params to stabilize training.
            torch.nn.utils.clip_grad_norm_(lora_params, 1.0)
            optimizer.step()

            loss_history.append(loss.item())
            avg_loss = sum(loss_history) / len(loss_history)

            pbar.set_postfix({'loss': f'{avg_loss:.4f}'})
            pbar.update(1)
            step += 1

            # Periodic intermediate checkpoints: lora_step500.safetensors, ...
            if step % args.save_every == 0:
                save_path = args.output.replace('.safetensors', f'_step{step}.safetensors')
                save_lora_weights(model, save_path)

            if step >= args.steps:
                break

    # Final weights always land at the exact --output path.
    save_lora_weights(model, args.output)
    print(f"\nTraining complete! LoRA weights saved to {args.output}")
| |
|
# Run training only when executed as a script, not on import.
if __name__ == '__main__':
    main()
| |
|