| """ |
| Utilities for checkpointing learning dynamics-related states (e.g. activations, weights, gradients). |
| |
| We save the learning dynamics states in a subdirectory of the checkpointing directory. |
| """ |
|
|
| import os |
| import re |
| from typing import Dict, Optional |
|
|
| import deepspeed |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from datasets import Dataset |
| from huggingface_hub import upload_folder |
| from lightning.fabric import Fabric |
| from lightning.fabric.strategies import DeepSpeedStrategy |
| from lightning.fabric.utilities.rank_zero import rank_zero_only |
| from torch.nn import functional as F |
| from torch.utils.data import DataLoader |
| from transformers import PreTrainedTokenizerBase |
|
|
| from src.config import CheckpointingConfig |
| from src.config.checkpointing_config import LearningDynamicsCheckpointingConfig |
| from src.training.utils.initialization import initialize_model |
| from src.training.utils.io import use_backoff |
|
|
|
|
| |
class DummyOptimizer(optim.Optimizer):
    """Optimizer with no hyper-parameters and no custom ``step`` logic.

    It only registers the given parameters with the ``Optimizer`` base class.
    Elsewhere in this file it is handed to ``fabric.setup`` alongside the model
    when the DeepSpeed strategy is in use.
    """

    def __init__(self, params):
        # No hyper-parameters are needed, so the defaults mapping stays empty.
        empty_defaults = dict()
        super().__init__(params, empty_defaults)
|
|
|
|
class CheckpointStateExtractor:
    """Extract per-layer activations, weights and (optionally) gradients from a model.

    Forward hooks are attached to every sub-module whose name contains one of
    ``learning_dynamics_config.layer_suffixes``. Each hook records the module output
    at ``learning_dynamics_config.sequence_idx`` and, once per module, the module's
    ``weight`` tensor. The states are intended for learning dynamics research.
    """

    def __init__(
        self,
        learning_dynamics_config: LearningDynamicsCheckpointingConfig,
        fabric: Fabric,
        model: nn.Module,
    ):
        """Store the configuration, the Fabric instance and the model to inspect.

        Args:
            learning_dynamics_config: Provides ``layer_suffixes`` and ``sequence_idx``.
            fabric: Used for device placement, ``all_gather`` and ``backward``.
            model: The model whose states are extracted.
        """
        self.learning_dynamics_config = learning_dynamics_config
        self.fabric = fabric
        self.model = model

    def extract_states(self, dataloader, compute_gradients: bool = False):
        """Extract model states (activations, weights, and optionally gradients).

        Runs a forward pass of the model on every batch of ``dataloader`` while forward
        hooks collect activations and weights. If ``compute_gradients`` is True, a
        cross-entropy loss is back-propagated for every batch and the gradients of the
        selected weight matrices are collected afterwards.

        Args:
            dataloader: Iterable of batches; each batch is a mapping with an
                ``"input_ids"`` entry and optionally a ``"labels"`` entry.
            compute_gradients: Whether to compute the gradients of the model parameters.

        Returns:
            A tuple ``(activations, weights, gradients)`` of dictionaries. Activations
            and weights are keyed by module name; gradients are keyed by parameter name
            with ``.weight`` removed. ``gradients`` is empty when ``compute_gradients``
            is False.
        """
        checkpoint_activations = {}
        checkpoint_weights = {}

        # The hooks fill the two dictionaries above as a side effect of each forward pass.
        forward_hooks = self._setup_forward_hooks(
            checkpoint_activations,
            checkpoint_weights,
        )

        try:
            for sub_batch in dataloader:
                _input_ids = torch.tensor(
                    sub_batch["input_ids"], device=self.fabric.device
                )

                if compute_gradients:
                    if "labels" in sub_batch:
                        input_ids = _input_ids
                        labels = torch.tensor(
                            sub_batch["labels"], device=self.fabric.device
                        )
                    else:
                        # No explicit labels: predict the next token, i.e. the inputs
                        # drop the last position and the labels drop the first one.
                        input_ids = _input_ids[:, :-1]
                        labels = _input_ids[:, 1:]
                else:
                    input_ids = _input_ids
                    labels = None

                if labels is None:
                    # Only the hooked states matter here; the model output is discarded.
                    with torch.no_grad():
                        _ = self.model(input_ids)
                else:
                    # assumes the model returns (logits, extra) with logits shaped
                    # (batch, seq, vocab) -- TODO confirm. cross_entropy expects the
                    # class dimension at index 1, hence the transpose.
                    outputs, _ = self.model(input_ids)
                    outputs = outputs.transpose(1, 2)
                    loss = F.cross_entropy(outputs, labels)
                    self.fabric.backward(loss, model=self.model)
        finally:
            # Always detach the hooks, even if a forward/backward pass failed, so the
            # model is not left with stale hooks writing into these dictionaries.
            for hook in forward_hooks:
                hook.remove()

        layer_suffixes = self.learning_dynamics_config.layer_suffixes
        checkpoint_gradients = {}
        if compute_gradients:
            for name, param in self.model.named_parameters():
                # Only the weight matrices of the selected layers are of interest.
                if (
                    any(layer_suffix in name for layer_suffix in layer_suffixes)
                    and "weight" in name
                ):
                    if isinstance(self.fabric.strategy, DeepSpeedStrategy):
                        _grad = deepspeed.utils.safe_get_full_grad(param)
                    else:
                        _grad = param.grad

                    assert _grad is not None, f"Gradient is None for layer: {name}"
                    name = re.sub(r"\.weight", "", name)
                    checkpoint_gradients[name] = _grad.detach().cpu()

        # Reset the gradients accumulated by the backward passes above.
        self.model.zero_grad()

        return checkpoint_activations, checkpoint_weights, checkpoint_gradients

    def _setup_forward_hooks(self, checkpoint_activations, checkpoint_weights):
        """Register forward hooks on every module selected by ``layer_suffixes``.

        A module is selected when any configured layer suffix occurs in its name.

        Args:
            checkpoint_activations: A dictionary to store the activations at each layer.
            checkpoint_weights: A dictionary to store the weights at each layer.

        Returns:
            The list of hook handles, so the caller can remove the hooks afterwards.
        """
        forward_hooks = []
        layer_suffixes = self.learning_dynamics_config.layer_suffixes

        for name, module in self.model.named_modules():
            if any(layer_suffix in name for layer_suffix in layer_suffixes):
                _forward_hook = module.register_forward_hook(
                    self._get_forward_hook(
                        name, checkpoint_activations, checkpoint_weights
                    )
                )
                forward_hooks.append(_forward_hook)
        return forward_hooks

    def _get_forward_hook(
        self, module_name, checkpoint_activations, checkpoint_weights
    ):
        """Build the forward hook for one module.

        The returned function is a closure over ``module_name`` and the two storage
        dictionaries.

        Args:
            module_name: The name of the module the hook is registered on.
            checkpoint_activations: A dictionary to store the activations at each layer.
            checkpoint_weights: A dictionary to store the weights at each layer.

        Returns:
            A forward hook for the given module.
        """

        def _forward_hook(module, _, module_out):
            sequence_idx = self.learning_dynamics_config.sequence_idx

            local_activations = module_out[:, sequence_idx, :].detach()

            # Collect the activations of all processes.
            gathered_activations = self.fabric.all_gather(local_activations)

            # all_gather prepends a process dimension only when several processes take
            # part. In that case swap it with the batch dimension so that flattening
            # interleaves the processes. On a single process no dimension is added and
            # transposing would scramble the (batch, hidden) layout, so skip it.
            if gathered_activations.dim() > local_activations.dim():
                gathered_activations = gathered_activations.transpose(0, 1)
            gathered_activations = gathered_activations.reshape(
                -1, gathered_activations.shape[-1]
            )
            gathered_activations = gathered_activations.detach().cpu()

            if module_name not in checkpoint_activations:
                # First batch for this module: store the activations and grab the
                # weight matrix once.
                checkpoint_activations[module_name] = gathered_activations
                checkpoint_weights[module_name] = module.weight.detach().cpu()
            else:
                # Later batches are appended along the sample dimension.
                checkpoint_activations[module_name] = torch.cat(
                    (
                        checkpoint_activations[module_name],
                        gathered_activations,
                    )
                )

        return _forward_hook
|
|
|
|
def compute_learning_dynamics_states(
    checkpointing_config: CheckpointingConfig,
    fabric: Fabric,
    model: nn.Module,
    dataset: Dataset,
    compute_gradients: bool = False,
) -> Dict[str, Dict[str, torch.Tensor]]:
    """Compute the learning dynamics states for a given checkpoint step.

    A temporary copy of ``model`` is built, set up with Fabric and handed to a
    ``CheckpointStateExtractor``, which extracts the activations, weights and
    optionally gradients over ``dataset``.

    Args:
        checkpointing_config: The configuration object for checkpointing.
        fabric: The Fabric instance for distributed training.
        model: The model to extract states from. It is moved to the CPU for the
            duration of the extraction and moved back to ``fabric.device`` afterwards.
        dataset: The dataset to extract states from.
        compute_gradients: Whether to compute the gradients of the model parameters.

    Returns:
        A dictionary with the keys ``"activations"``, ``"weights"`` and ``"gradients"``,
        each mapping layer names to tensors.

    Raises:
        ValueError: If fewer activations than dataset samples were collected for a layer.
    """
    # Synchronise all ranks, then move the training model off the device
    # (presumably to make room for the temporary copy -- confirm).
    fabric.barrier()
    model.to("cpu")

    def _collate_fn(batch):
        # Keep only the token ids; the extractor turns them into tensors itself.
        return {"input_ids": [entry["input_ids"] for entry in batch]}

    batch_size = checkpointing_config.learning_dynamics.batch_size
    # Each rank processes its share of the batch. Never go below one sample per rank:
    # DataLoader rejects a batch size of 0 (batch_size < world_size).
    sub_batch_size = max(1, batch_size // fabric.world_size)

    # shuffle=False keeps the samples in dataset order; the extractor's hooks
    # interleave the per-rank activations back into that order.
    extractor_dataloader = DataLoader(
        dataset,
        batch_size=sub_batch_size,
        shuffle=False,
        collate_fn=_collate_fn,
        drop_last=False,
    )
    extractor_dataloader = fabric.setup_dataloaders(
        extractor_dataloader, use_distributed_sampler=True
    )

    # Build a separate model from the same config and copy the current weights into
    # it, so that hooks and gradients never touch the training model.
    _model = initialize_model(model.config)
    _model.load_state_dict(model.state_dict())

    if isinstance(fabric.strategy, DeepSpeedStrategy):
        # Under DeepSpeed the model is set up together with an optimizer; a
        # placeholder optimizer is passed for that purpose.
        _model, _ = fabric.setup(_model, DummyOptimizer(_model.parameters()))
    else:
        _model = fabric.setup(_model)

    _model.zero_grad()

    state_extractor = CheckpointStateExtractor(
        checkpointing_config.learning_dynamics, fabric, _model
    )

    checkpoint_activations, checkpoint_weights, checkpoint_gradients = (
        state_extractor.extract_states(
            extractor_dataloader, compute_gradients=compute_gradients
        )
    )

    # The extractor also holds a reference to the temporary model; drop both so the
    # copy can actually be released before the cache is emptied.
    del state_extractor
    del _model
    torch.cuda.empty_cache()

    # Wait for every rank to finish before restoring the training model.
    fabric.barrier()

    model.to(fabric.device)

    # More activations than samples can be collected (presumably because the
    # distributed sampler pads the dataset -- confirm); drop the surplus rows.
    num_samples = len(dataset)
    for layer_name, layer_activations in checkpoint_activations.items():
        if len(layer_activations) > num_samples:
            checkpoint_activations[layer_name] = layer_activations[:num_samples]
        elif len(layer_activations) < num_samples:
            raise ValueError(
                f"Number of activations ({len(layer_activations)}) in layer {layer_name} does not match number of samples in dataset ({len(dataset)})"
            )

    return {
        "activations": checkpoint_activations,
        "weights": checkpoint_weights,
        "gradients": checkpoint_gradients,
    }
|
|
|
|
@rank_zero_only
@use_backoff()
def save_learning_dynamics_states(
    checkpointing_config: CheckpointingConfig,
    checkpoint_step: int,
    prefix: str,
    fabric: Fabric,
    learning_dynamics_states: Dict[str, torch.Tensor],
    learning_dynamics_dataset: Optional[Dataset] = None,
    tokenizer: Optional[PreTrainedTokenizerBase] = None,
) -> None:
    """Write the learning dynamics states into the checkpoint directory.

    Every non-empty entry of ``learning_dynamics_states`` is stored with ``torch.save``.
    If ``learning_dynamics_dataset`` is given it is saved next to the states as a
    HuggingFace dataset. If a ``tokenizer`` is given as well, the saved dataset is
    rebuilt with two columns: ``input_ids`` and the decoded ``text``.

    Resulting directory layout:

        {checkpointing_config.runs_dir}/
        `-- {checkpointing_config.run_name}/
            `-- {checkpointing_config.checkpoints_dir}/
                |-- step_{checkpoint_step}/
                |   `-- {checkpointing_config.learning_dynamics_dir}/
                |       |-- {prefix}_activations.pt
                |       |-- {prefix}_weights.pt
                |       |-- {prefix}_gradients.pt
                |       `-- {prefix}_data/   # if learning_dynamics_dataset is provided
                `-- latest -> step_{checkpoint_step}/

    When ``checkpointing_config.save_to_hf`` is set, the learning dynamics folder is
    also uploaded with ``upload_folder``.

    NOTE: the function is decorated with ``rank_zero_only`` and ``use_backoff``
    (presumably a retry wrapper -- confirm against ``src.training.utils.io``).

    Args:
        checkpointing_config: The configuration object for checkpointing.
        checkpoint_step: The checkpoint step at which the states were computed.
        prefix: The file-name prefix for the saved states and dataset.
        fabric: The Fabric instance for distributed training (not used in the body).
        learning_dynamics_states: The learning dynamics states to save.
        learning_dynamics_dataset: Dataset with an ``input_ids`` column. (optional)
        tokenizer: The tokenizer used to decode input IDs into text. (optional)
    """
    step_dir_name = f"step_{checkpoint_step}"
    path_parts = (
        checkpointing_config.runs_dir,
        checkpointing_config.run_name,
        checkpointing_config.checkpoints_dir,
        step_dir_name,
        checkpointing_config.learning_dynamics_dir,
    )
    learning_dynamics_path = os.path.join(*path_parts)
    os.makedirs(learning_dynamics_path, exist_ok=True)

    # One file per state; missing or empty states are skipped.
    for state_name, state in learning_dynamics_states.items():
        if state is None or len(state) == 0:
            continue
        target_file = os.path.join(
            learning_dynamics_path, f"{prefix}_{state_name}.pt"
        )
        torch.save(state, target_file)

    if learning_dynamics_dataset is not None:
        if tokenizer is not None:
            # Replace the dataset with its token ids plus a decoded text column.
            token_rows = [row["input_ids"] for row in learning_dynamics_dataset]
            decoded_rows = [
                tokenizer.decode(ids, skip_special_tokens=True) for ids in token_rows
            ]
            learning_dynamics_dataset = Dataset.from_dict(
                {"input_ids": token_rows, "text": decoded_rows}
            )

        learning_dynamics_dataset.save_to_disk(
            os.path.join(learning_dynamics_path, f"{prefix}_data")
        )

    if checkpointing_config.save_to_hf:
        # Upload the folder to the run's branch of the configured repository.
        upload_folder(
            folder_path=learning_dynamics_path,
            path_in_repo=path_parts[-1],
            repo_id=checkpointing_config.hf_checkpoint.repo_id,
            commit_message=f"Saving Learning Dynamics Data ({prefix}) -- Step {checkpoint_step}",
            revision=checkpointing_config.run_name,
            token=os.getenv("HF_TOKEN"),
        )
|
|