| | import argparse
|
| | import os
|
| | from pathlib import Path
|
| | import torch
|
| | from torch.utils.data import Dataset, DataLoader
|
| | import numpy as np
|
| | from accelerate import Accelerator
|
| | from transformers import AutoModelForCausalLM, get_linear_schedule_with_warmup
|
| | from torch.optim import AdamW
|
| | from tqdm import tqdm
|
| | import gc
|
| | import traceback
|
| | import matplotlib.pyplot as plt
|
| | from anticipation.vocab import ANTICIPATE, AUTOREGRESS
|
| |
|
| |
|
def print_gpu_memory_stats():
    """Print allocated/reserved/peak memory (in MB) for every visible CUDA device.

    No-op when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return
    mib = 1024**2  # bytes per mebibyte, hoisted out of the loop
    for idx in range(torch.cuda.device_count()):
        print(f"GPU {idx} memory allocated: {torch.cuda.memory_allocated(idx) / mib:.2f} MB")
        print(f"GPU {idx} memory reserved: {torch.cuda.memory_reserved(idx) / mib:.2f} MB")
        print(f"GPU {idx} max memory allocated: {torch.cuda.max_memory_allocated(idx) / mib:.2f} MB")
|
| |
|
| |
|
def check_model_for_nans(model):
    """Return True if any parameter of `model` contains a NaN value.

    Reports (prints) the first offending parameter name, then stops scanning.
    """
    for name, tensor in model.named_parameters():
        if not torch.isnan(tensor).any():
            continue
        print(f"NaN detected in parameter {name}")
        return True
    return False
|
| |
|
| |
|
# Select the compute device once at import time and describe the environment.
# main() may later rebind the module-level `device` to CPU via --force_cpu.
if torch.cuda.is_available():
    device = torch.device("cuda")
    device_count = torch.cuda.device_count()
    print(f"✓ CUDA is available with {device_count} device(s)")
    for gpu_index in range(device_count):
        print(f" Device {gpu_index}: {torch.cuda.get_device_name(gpu_index)}")
        props = torch.cuda.get_device_properties(gpu_index)
        print(f" - Total memory: {props.total_memory / 1024**3:.2f} GB")
        print(f" - CUDA capability: {props.major}.{props.minor}")
else:
    device = torch.device("cpu")
    print("✗ CUDA is not available! Training will be much slower on CPU.")

print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA version: {torch.version.cuda}")
|
| |
|
class SequencePackedDataset(Dataset):
    """Dataset that packs multiple tokenized sequences into fixed-length contexts.

    Each line of the input file is one tokenized sequence whose first token is a
    control flag (AUTOREGRESS or ANTICIPATE). Sequences are shuffled and greedily
    packed into `context_length`-token examples, separated by three SEPARATOR
    tokens; the attention mask is 1 over real sequence tokens (including each
    control flag) and 0 over separators/padding. Labels equal input_ids.
    """

    def __init__(self, file_path, context_length=1024, max_packed_sequences=4):
        """Load data from tokenized file and implement sequence packing

        Args:
            file_path: Path to the tokenized data file
            context_length: Maximum context length (default 1024)
            max_packed_sequences: Maximum number of sequences to pack together (default 4)
        """
        from anticipation.vocab import SEPARATOR, AUTOREGRESS, ANTICIPATE

        individual_sequences = []
        with open(file_path, 'r') as f:
            for line in f:
                tokens = list(map(int, line.strip().split()))
                individual_sequences.append(tokens)

        print(f"Loaded {len(individual_sequences)} individual sequences")

        self.packed_sequences = []
        self.attention_masks = []

        # Packing statistics.
        self.total_packed = 0
        self.avg_sequences_per_pack = 0
        sequences_per_pack = []

        # Shuffle so packs mix sequences from across the file.
        import random
        random.shuffle(individual_sequences)

        current_packed = []      # token buffer for the pack being built
        current_positions = []   # (start, end) spans of real tokens in the buffer

        def flush_pack():
            """Pad the current pack to context_length, record it, reset buffers."""
            nonlocal current_packed, current_positions
            if not current_packed:
                return
            # Mask is 1 over real sequence tokens only; separator runs and
            # trailing padding stay 0. (Slicing clamps `end` > context_length.)
            attention_mask = torch.zeros(context_length, dtype=torch.long)
            for start, end in current_positions:
                attention_mask[start:end] = 1
            if len(current_packed) < context_length:
                padding_length = context_length - len(current_packed)
                current_packed.extend([SEPARATOR] * padding_length)
            self.packed_sequences.append(torch.tensor(current_packed[:context_length], dtype=torch.long))
            self.attention_masks.append(attention_mask)
            sequences_per_pack.append(len(current_positions))
            self.total_packed += 1
            current_packed = []
            current_positions = []

        for sequence in individual_sequences:
            control_flag = sequence[0]
            assert control_flag in [AUTOREGRESS, ANTICIPATE], f"Invalid control flag: {control_flag}"
            sequence_content = sequence[1:]

            # Flush when the next sequence (plus a 3-token separator) would
            # overflow the context, or when the CURRENT pack already holds
            # max_packed_sequences sequences.
            # Bug fix: the original tested len(sequences_per_pack) — the global
            # list of *finished* packs — so once max_packed_sequences packs
            # existed, every later pack degenerated to a single sequence.
            if len(current_packed) > 0 and (
                len(current_packed) + 3 + len(sequence_content) > context_length
                or len(current_positions) >= max_packed_sequences
            ):
                flush_pack()

            start_pos = len(current_packed)
            if len(current_packed) > 0:
                # Three separators between packed sequences.
                current_packed.extend([SEPARATOR, SEPARATOR, SEPARATOR])
                start_pos += 3

            # The control flag is part of the attended span.
            current_packed.append(control_flag)
            current_packed.extend(sequence_content)
            current_positions.append((start_pos, len(current_packed)))

        # Flush whatever remains after the last sequence.
        flush_pack()

        if sequences_per_pack:
            self.avg_sequences_per_pack = sum(sequences_per_pack) / len(sequences_per_pack)

        print(f"Created {len(self.packed_sequences)} packed sequences")
        print(f"Average sequences per pack: {self.avg_sequences_per_pack:.2f}")

    def __len__(self):
        """Number of packed examples."""
        return len(self.packed_sequences)

    def __getitem__(self, idx):
        """Return one packed example; labels mirror input_ids (causal LM)."""
        return {
            "input_ids": self.packed_sequences[idx],
            "attention_mask": self.attention_masks[idx],
            "labels": self.packed_sequences[idx],
        }
|
| |
|
def collate_packed_sequences(batch):
    """Collate packed-sequence examples into one batch dict, attention masks included."""
    return {
        field: torch.stack([example[field] for example in batch])
        for field in ("input_ids", "attention_mask", "labels")
    }
|
| |
|
def evaluate_model(model, dataloader, accelerator):
    """Compute the batch-size-weighted mean loss of `model` over `dataloader`.

    Args:
        model: Causal LM whose forward pass returns an object with a `.loss`.
        dataloader: Iterable of batch dicts passed directly to the model.
        accelerator: Accepted for interface compatibility but not used here —
            losses are averaged locally, not gathered across processes.

    Returns:
        Average per-sample loss, or NaN for an empty dataloader (the original
        raised ZeroDivisionError in that case).

    Note: leaves the model in eval mode; callers must restore train mode.
    """
    model.eval()
    total_loss = 0
    total_samples = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            outputs = model(**batch)
            loss = outputs.loss

            # Weight by batch size so a smaller final batch doesn't skew the mean.
            batch_size = batch["input_ids"].size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

    # Guard against an empty dataloader instead of dividing by zero.
    if total_samples == 0:
        return float("nan")
    return total_loss / total_samples
|
| |
|
def plot_losses(train_losses, val_losses, validation_steps, output_dir):
    """Render the training/validation loss curves and save them to loss_plot.png.

    Args:
        train_losses (list): Training loss history
        val_losses (list): Validation loss history
        validation_steps (list): Steps at which validation was performed
        output_dir (Path): Directory to save the plot
    """
    # Object-oriented matplotlib API instead of the pyplot state machine.
    fig, ax = plt.subplots(figsize=(10, 6))

    train_axis = range(1, len(train_losses) + 1)
    ax.plot(train_axis, train_losses, label='Training Loss', alpha=0.7, color='blue')
    ax.plot(validation_steps, val_losses, label='Validation Loss',
            linestyle='--', marker='o', markersize=5, color='red')

    ax.set_xlabel('Steps (x10)')
    ax.set_ylabel('Loss')
    ax.set_title('Training and Validation Loss')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plot_path = output_dir / "loss_plot.png"
    fig.savefig(plot_path)
    plt.close(fig)

    print(f"Loss plot saved to {plot_path}")
|
| |
|
def main():
    """Fine-tune a pretrained causal LM on tokenized event sequences.

    Workflow: parse CLI args, build (optionally sequence-packed) train/val
    datasets, load the model, then run a gradient-accumulated training loop
    with NaN guards, periodic validation, checkpointing, and loss plots.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_file', type=Path, default=Path('./data/train.txt'))
    parser.add_argument('--val_file', type=Path, default=Path('./data/test.txt'))
    parser.add_argument('--model_name', type=str, default='stanford-crfm/music-small-800k')
    parser.add_argument('--output_dir', type=Path, default=Path('./fine_tuned'))
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--val_batch_size', type=int, default=16)
    parser.add_argument('--gradient_accumulation_steps', type=int, default=32)
    parser.add_argument('--learning_rate', type=float, default=3e-5)
    parser.add_argument('--max_steps', type=int, default=3500)
    parser.add_argument('--save_steps', type=int, default=500)
    parser.add_argument('--eval_steps', type=int, default=100)
    parser.add_argument('--warmup_steps', type=int, default=500)
    parser.add_argument('--force_cpu', action='store_true', help='Force CPU usage even if GPU is available')
    parser.add_argument('--reduce_memory', action='store_true', help='Use memory-saving techniques')
    parser.add_argument('--context_length', type=int, default=1024, help='Maximum context length')
    parser.add_argument('--max_packed_sequences', type=int, default=4,
                        help='Maximum number of sequences to pack together (set to 1 to disable packing)')
    args = parser.parse_args()

    # Rebind the module-level `device` chosen at import time if CPU was forced.
    global device
    if args.force_cpu:
        device = torch.device("cpu")
        print("Forcing CPU usage as requested")

    print(f"Effective batch size: {args.batch_size * args.gradient_accumulation_steps}")
    print(f"Final device confirmation: {device}")

    try:
        # bf16 mixed precision only when actually training on GPU.
        mixed_precision = 'bf16' if torch.cuda.is_available() and not args.force_cpu else 'no'
        print(f"Mixed precision mode: {mixed_precision}")

        accelerator = Accelerator(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            cpu=args.force_cpu,
            mixed_precision=mixed_precision,
        )

        os.makedirs(args.output_dir, exist_ok=True)

        print("Initial GPU memory stats:")
        print_gpu_memory_stats()

        # ---- Training data ----
        print(f"Loading training dataset from {args.data_file}...")
        if args.max_packed_sequences > 1:
            print(f"Using sequence packing with max {args.max_packed_sequences} sequences per pack")
            train_dataset = SequencePackedDataset(
                args.data_file,
                context_length=args.context_length,
                max_packed_sequences=args.max_packed_sequences
            )
            collate_fn_train = collate_packed_sequences
        else:
            print("Sequence packing disabled - using single sequences")

            # SEPARATOR imported for parity with the packed path; unused here.
            from anticipation.vocab import SEPARATOR
            individual_sequences = []
            with open(args.data_file, 'r') as f:
                for line in f:
                    tokens = list(map(int, line.strip().split()))
                    individual_sequences.append(torch.tensor(tokens, dtype=torch.long))

            class TokenizedDataset(Dataset):
                # Minimal dataset: one file line per example; labels mirror
                # input_ids (standard causal-LM objective).
                def __init__(self, sequences):
                    self.sequences = sequences
                    self.sequence_length = len(self.sequences[0]) if self.sequences else 0
                    print(f"Loaded {len(self.sequences)} sequences with length {self.sequence_length}")

                def __len__(self):
                    return len(self.sequences)

                def __getitem__(self, idx):
                    tokens = self.sequences[idx]
                    return {"input_ids": tokens, "labels": tokens}

            train_dataset = TokenizedDataset(individual_sequences)

            def collate_fn_train(batch):
                # NOTE(review): assumes every sequence in the file has the same
                # length — torch.stack raises otherwise. TODO confirm upstream.
                input_ids = torch.stack([item["input_ids"] for item in batch])
                labels = torch.stack([item["labels"] for item in batch])
                return {"input_ids": input_ids, "labels": labels}

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            collate_fn=collate_fn_train,
            pin_memory=torch.cuda.is_available() and not args.force_cpu,
            num_workers=0,
        )

        # ---- Validation data (mirrors the training path) ----
        print(f"Loading validation dataset from {args.val_file}...")
        if args.max_packed_sequences > 1:
            val_dataset = SequencePackedDataset(
                args.val_file,
                context_length=args.context_length,
                max_packed_sequences=args.max_packed_sequences
            )
            collate_fn_val = collate_packed_sequences
        else:

            val_sequences = []
            with open(args.val_file, 'r') as f:
                for line in f:
                    tokens = list(map(int, line.strip().split()))
                    val_sequences.append(torch.tensor(tokens, dtype=torch.long))

            val_dataset = TokenizedDataset(val_sequences)
            collate_fn_val = collate_fn_train

        val_dataloader = DataLoader(
            val_dataset,
            batch_size=args.val_batch_size,
            shuffle=False,
            collate_fn=collate_fn_val,
            pin_memory=torch.cuda.is_available() and not args.force_cpu,
            num_workers=0,
        )

        # ---- Model ----
        print(f"Loading model {args.model_name}...")
        model_kwargs = {
            "trust_remote_code": True,
            # KV-cache is useless during training and wastes memory.
            "use_cache": False,
        }

        if args.reduce_memory and torch.cuda.is_available():
            print("Using memory reduction techniques...")
            # Load weights directly in bf16 to halve the parameter footprint.
            model_kwargs.update({
                "torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32,
                "low_cpu_mem_usage": True,
            })

        try:
            model = AutoModelForCausalLM.from_pretrained(
                args.model_name,
                **model_kwargs
            )
        except Exception as e:
            # Fall back to a plain load if the optional kwargs are unsupported.
            print(f"Error loading model with advanced options: {e}")
            print("Trying with basic options...")
            model = AutoModelForCausalLM.from_pretrained(
                args.model_name,
                trust_remote_code=True,
                use_cache=False
            )

        print("GPU memory after loading model:")
        print_gpu_memory_stats()

        model = model.to(device)
        print(f"Model moved to: {next(model.parameters()).device}")

        optimizer = AdamW(
            model.parameters(),
            lr=args.learning_rate,
            eps=1e-6,
            weight_decay=0.01,
            betas=(0.9, 0.999),
        )

        # Let Accelerate wrap model/optimizer/dataloader (device placement,
        # gradient accumulation, mixed precision).
        model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
        val_dataloader = accelerator.prepare_data_loader(val_dataloader)
        print(f"After accelerator preparation, model device: {next(model.parameters()).device}")

        scheduler = get_linear_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=args.max_steps,
        )

        print("GPU memory before training:")
        print_gpu_memory_stats()

        # Anomaly detection off + cudnn autotuning on: favor speed over debugging.
        torch.autograd.set_detect_anomaly(False)

        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True

        if torch.cuda.is_available():
            print("Clearing CUDA cache before training")
            torch.cuda.empty_cache()
            torch.cuda.set_device(0)

        # ---- Training loop ----
        print("Starting training...")
        model.train()
        completed_steps = 0  # optimizer steps (counted only on sync steps)
        step = 0  # NOTE(review): unused

        # Loss histories for checkpoint .npz files and plots.
        train_losses = []
        val_losses = []
        validation_steps = []

        progress_bar = tqdm(total=args.max_steps, desc="Training", disable=False)

        try:
            while completed_steps < args.max_steps:
                for batch in train_dataloader:
                    try:
                        with accelerator.accumulate(model):

                            outputs = model(**batch)
                            loss = outputs.loss

                            # Skip batches producing a non-finite loss.
                            if torch.isnan(loss).any() or torch.isinf(loss).any():
                                print(f"WARNING: NaN or Inf loss detected: {loss.item()}")

                                optimizer.zero_grad()
                                continue

                            accelerator.backward(loss)

                            # sync_gradients is True only on the final
                            # micro-batch of an accumulation window.
                            if accelerator.sync_gradients:

                                accelerator.clip_grad_norm_(model.parameters(), max_norm=0.5)

                                # Drop the whole update if any gradient is NaN.
                                has_nan_grads = False
                                for name, param in model.named_parameters():
                                    if param.grad is not None and torch.isnan(param.grad).any():
                                        print(f"NaN gradient detected in {name}")
                                        has_nan_grads = True
                                        break

                                if has_nan_grads:
                                    print("Skipping update due to NaN gradients")
                                    optimizer.zero_grad()
                                    continue

                                optimizer.step()
                                scheduler.step()
                                optimizer.zero_grad()

                                completed_steps += 1
                                progress_bar.update(1)

                                # Record/report the training loss every 10 steps.
                                if completed_steps % 10 == 0:

                                    train_losses.append(loss.item())

                                    print(f"Step: {completed_steps}/{args.max_steps}, Loss: {loss.item():.4f}, "
                                          f"LR: {scheduler.get_last_lr()[0]:.8e}")

                                    if check_model_for_nans(model):
                                        print("NaN parameters detected in model! Training may be unstable.")

                                if completed_steps % 100 == 0:
                                    print_gpu_memory_stats()

                                # Periodic validation.
                                if completed_steps % args.eval_steps == 0:
                                    print(f"\nRunning validation at step {completed_steps}...")
                                    val_loss = evaluate_model(model, val_dataloader, accelerator)
                                    # x-axis of the loss plot is in units of 10 steps.
                                    validation_steps.append(completed_steps // 10)
                                    val_losses.append(val_loss)
                                    print(f"Validation Loss: {val_loss:.4f}")

                                    # evaluate_model left the model in eval mode.
                                    model.train()

                                    if torch.cuda.is_available():
                                        torch.cuda.empty_cache()
                                        gc.collect()

                                # Periodic checkpointing (model + loss history + plot).
                                if completed_steps % args.save_steps == 0:
                                    checkpoint_dir = args.output_dir / f"checkpoint-{completed_steps}"
                                    os.makedirs(checkpoint_dir, exist_ok=True)

                                    unwrapped_model = accelerator.unwrap_model(model)
                                    unwrapped_model.save_pretrained(
                                        checkpoint_dir,
                                        is_main_process=accelerator.is_main_process,
                                        save_function=accelerator.save,
                                    )
                                    print(f"Saved checkpoint to {checkpoint_dir}")

                                    np.savez(
                                        checkpoint_dir / "losses.npz",
                                        train_losses=np.array(train_losses),
                                        val_losses=np.array(val_losses),
                                        validation_steps=np.array(validation_steps)
                                    )

                                    plot_losses(train_losses, val_losses, validation_steps, checkpoint_dir)

                                    if torch.cuda.is_available():
                                        torch.cuda.empty_cache()
                                        gc.collect()

                            # NOTE(review): this zeroes gradients on every
                            # non-sync micro-batch, which appears to discard the
                            # accumulation Accelerate performs — confirm intent.
                            if not accelerator.sync_gradients:
                                optimizer.zero_grad()

                        if completed_steps >= args.max_steps:
                            break

                    except RuntimeError as e:
                        if "CUDA out of memory" in str(e):
                            # OOM is not recoverable here — report and abort.
                            print(f"CUDA OOM error! Current batch size: {args.batch_size}")
                            print(f"Current memory usage:")
                            print_gpu_memory_stats()
                            print("Consider reducing batch size or model size.")
                            print(f"Error details: {str(e)}")
                            raise
                        elif "nan" in str(e).lower() or "inf" in str(e).lower():
                            # Numerical blowups: skip the batch and keep going.
                            print(f"NaN/Inf error: {str(e)}")
                            print("Trying to recover by skipping this batch...")
                            optimizer.zero_grad()
                            continue
                        else:
                            print(f"Runtime error: {str(e)}")
                            print(traceback.format_exc())
                            raise

        except Exception as e:
            print(f"Error during training: {e}")
            print(traceback.format_exc())
            raise
        finally:

            progress_bar.close()

        # ---- Final validation, save, and plot (reached only if training
        # finished without raising) ----
        try:

            print("\nRunning final validation...")
            final_val_loss = evaluate_model(model, val_dataloader, accelerator)
            validation_steps.append(completed_steps // 10)
            val_losses.append(final_val_loss)
            print(f"Final validation Loss: {final_val_loss:.4f}")

            final_dir = args.output_dir / "final"
            os.makedirs(final_dir, exist_ok=True)
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(
                final_dir,
                is_main_process=accelerator.is_main_process,
                save_function=accelerator.save,
            )
            print(f"Saved final model to {final_dir}")

            np.savez(
                final_dir / "losses.npz",
                train_losses=np.array(train_losses),
                val_losses=np.array(val_losses),
                validation_steps=np.array(validation_steps)
            )

            plot_losses(train_losses, val_losses, validation_steps, final_dir)

        except Exception as save_error:
            # Best-effort: a failed save/plot should not crash the run.
            print(f"Error saving final model or generating plot: {save_error}")

    except Exception as setup_error:
        print(f"Error in setup: {setup_error}")
        print(traceback.format_exc())
|
| |
|
| | if __name__ == "__main__":
|
| | main() |