| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """ |
| | Contain small torch utilities |
| | """ |
| |
|
| | from typing import List, Literal, Optional, Tuple, Union |
| |
|
| | import torch |
| | import torch.distributed |
| | import torch.nn.functional as F |
| | from torch.optim.lr_scheduler import LambdaLR |
| |
|
| |
|
| | try: |
| | from flash_attn.ops.triton.cross_entropy import cross_entropy_loss |
| |
|
| | FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = True |
| | except ImportError: |
| | FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = False |
| |
|
| |
|
@torch.compiler.disable()
def log_probs_from_logits_flash_attn(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Compute log probs of the labels with the flash-attn cross entropy kernel.

    Args:
        logits (torch.Tensor): logits passed straight to ``cross_entropy_loss``.
        labels (torch.Tensor): label ids passed straight to ``cross_entropy_loss``.

    Returns:
        torch.Tensor: the negated first element of the kernel's tuple result.

    Raises:
        ValueError: if the kernel does not return a tuple.
    """
    result = cross_entropy_loss(logits, labels, inplace_backward=True)
    if isinstance(result, tuple):
        # The first tuple entry holds the losses; negating them gives log probs.
        losses = result[0]
        return -losses

    raise ValueError(
        "please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses]."
    )
| |
|
| |
|
def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Compute log probs on the label ids given logits.

    The flash-attn kernel is used when it was importable; otherwise the
    computation falls back to ``F.cross_entropy`` on logits upcast to fp32.
    Both paths return log probs (i.e. the negated cross entropy).

    Args:
        logits (torch.Tensor): logits of the model, shape (batch_size, seqlen, vocab_size)
        labels (torch.Tensor): labels of the model, shape (batch_size, seqlen)

    Returns:
        torch.Tensor: log probs of the labels, shape (batch_size, seqlen)
    """
    batch_dim = logits.shape[:-1]
    vocab_dim = logits.shape[-1]
    # Flatten every leading dim so both kernels see (num_tokens, vocab_size).
    logits = logits.contiguous().view(-1, vocab_dim)
    labels = labels.contiguous().view(-1)
    if FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE:
        output = log_probs_from_logits_flash_attn(logits, labels)
    else:
        # Cross entropy is the negative log-likelihood, so it must be negated
        # to yield log probs, matching the flash-attn path.
        output = -F.cross_entropy(logits.float(), labels, reduction="none")

    return output.view(*batch_dim)
| |
|
| |
|
def masked_mean(values: torch.Tensor, mask: torch.Tensor, dim: int = None, eps: float = 1e-8) -> torch.Tensor:
    """Average ``values`` over the positions weighted by ``mask``.

    Args:
        values (torch.Tensor): tensor to average.
        mask (torch.Tensor): multiplicative weights applied to ``values``.
        dim (int, optional): dimension to reduce; ``None`` reduces over all elements.
        eps (float): added to the mask sum so an all-zero mask does not divide by zero.

    Returns:
        torch.Tensor: the masked sum divided by ``mask.sum(dim) + eps``.
    """
    masked_total = (values * mask).sum(dim=dim)
    mask_total = mask.sum(dim=dim)
    return masked_total / (mask_total + eps)
| |
|
| |
|
def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
    """Compute the variance of ``values`` over the positions weighted by ``mask``.

    Args:
        values (torch.Tensor): tensor whose variance is computed.
        mask (torch.Tensor): mask passed to ``masked_mean``.
        unbiased (bool): apply Bessel's correction when the mask sum exceeds one.

    Returns:
        torch.Tensor: the masked variance. When ``unbiased`` is set but the mask
        sum is at most one, a message is printed and the uncorrected variance
        is returned.
    """
    deviations = values - masked_mean(values, mask)
    variance = masked_mean(deviations.pow(2), mask)
    if not unbiased:
        return variance

    num_selected = mask.sum()
    if num_selected <= 1:
        # Bessel's correction divides by (n - 1), so it is skipped here.
        print("The sum of the mask is less than one, which can cause a division by zero.")
        return variance

    return variance * (num_selected / (num_selected - 1))
| |
|
| |
|
def masked_whiten(values: torch.Tensor, mask: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Whiten ``values`` using statistics computed over the masked positions.

    Args:
        values (torch.Tensor): tensor to normalise.
        mask (torch.Tensor): mask passed to ``masked_mean`` and ``masked_var``.
        eps (float): added to the variance before the reciprocal square root.

    Returns:
        torch.Tensor: ``(values - mean) / sqrt(var + eps)``.
    """
    mean = masked_mean(values, mask)
    var = masked_var(values, mask)
    inv_std = torch.rsqrt(var + eps)
    return (values - mean) * inv_std
| |
|
| |
|
def get_response_mask(
    response_ids: torch.Tensor, eos_token_id: Union[int, List[int]] = 2, dtype: torch.dtype = torch.long
):
    """Build a mask over the response ids that is 0 after the first eos token.

    The first eos token itself stays inside the mask. ``eos_token_id`` may be a
    single id or a list of ids, e.g. 1 or [1, 2]. The cumulative sum runs along
    dim 1.
    ```
    e.g. eos_token = 1
    response_ids:  [0, 0, 2, 4, 3, 5, 1, 0, 0]
    response_mask: [1, 1, 1, 1, 1, 1, 1, 0, 0]
    ```
    """
    eos_ids = [eos_token_id] if isinstance(eos_token_id, int) else eos_token_id

    # Mark every position holding any of the eos ids.
    is_eos = torch.zeros_like(response_ids, dtype=torch.bool)
    for eos_id in eos_ids:
        is_eos = torch.logical_or(is_eos, response_ids.eq(eos_id))

    # Inclusive cumsum minus the current hit counts the eos tokens strictly
    # before each position; a position is kept only when that count is zero.
    eos_hits = is_eos.long()
    eos_seen_before = torch.cumsum(eos_hits, dim=1) - eos_hits
    return (eos_seen_before == 0).to(dtype)
| |
|
| |
|
def pad_2d_list_to_length(
    response: List[List[int]], pad_token_id: int, max_length: Optional[int] = None
) -> torch.Tensor:
    """Right-pad a 2D list (e.g. responses, log_probs) into a 2D tensor.

    Every row is padded with ``pad_token_id`` up to the longest row, or up to
    ``max_length`` when that is given and larger than the longest row.
    """
    longest_row = max(len(row) for row in response)
    target_length = longest_row if max_length is None else max(longest_row, max_length)

    padded_rows = [list(row) + [pad_token_id] * (target_length - len(row)) for row in response]
    return torch.tensor(padded_rows)
| |
|
| |
|
def pad_sequence_to_length(
    tensor: torch.Tensor, max_seq_len: int, pad_token_id: int, left_pad: bool = False
) -> torch.Tensor:
    """Pad an nD tensor along its last dim up to ``max_seq_len``.

    The input is returned unchanged when its last dim is already at least
    ``max_seq_len``. Padding goes on the left when ``left_pad`` is set,
    otherwise on the right.
    """
    missing = max_seq_len - tensor.size(-1)
    if missing <= 0:
        return tensor

    # Same leading dims as the input, with only the missing columns in the last dim.
    padding_shape = (*tensor.shape[:-1], missing)
    padding = torch.full(padding_shape, fill_value=pad_token_id, dtype=tensor.dtype, device=tensor.device)
    pieces = (padding, tensor) if left_pad else (tensor, padding)
    return torch.cat(pieces, dim=-1)
| |
|
| |
|
def postprocess_data(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    position_ids: torch.Tensor,
    max_length: int,
    pad_token_id: int,
    left_pad: bool = True,
    truncation: Literal["left", "right", "error"] = "error",
):
    """Pad or truncate data along the last (sequence) dimension.

    Args:
        input_ids (torch.Tensor): token ids; its last dim is the sequence length.
        attention_mask (torch.Tensor): padded with 0 / truncated like ``input_ids``.
        position_ids (torch.Tensor): padded with 0 / truncated like ``input_ids``.
        max_length (int): target sequence length.
        pad_token_id (int): fill value used when padding ``input_ids``.
        left_pad (bool): pad on the left when True, otherwise on the right.
        truncation: "left" keeps the last ``max_length`` positions, "right"
            keeps the first ``max_length`` positions, "error" raises.

    Returns:
        Tuple of (input_ids, attention_mask, position_ids).

    Raises:
        NotImplementedError: if the sequence is longer than ``max_length`` and
            ``truncation`` is "error" (or unknown).
    """
    assert truncation in ["left", "right", "error"]
    # Measure the last dim: padding and truncation below both act on it, so
    # len() (the first dim) would be wrong for batched inputs.
    seq_length = input_ids.size(-1)
    if seq_length < max_length:
        input_ids = pad_sequence_to_length(
            input_ids, max_seq_len=max_length, pad_token_id=pad_token_id, left_pad=left_pad
        )
        attention_mask = pad_sequence_to_length(
            attention_mask, max_seq_len=max_length, pad_token_id=0, left_pad=left_pad
        )
        position_ids = pad_sequence_to_length(position_ids, max_seq_len=max_length, pad_token_id=0, left_pad=left_pad)
    elif seq_length > max_length:
        if truncation == "left":
            input_ids = input_ids[..., -max_length:]
            attention_mask = attention_mask[..., -max_length:]
            position_ids = position_ids[..., -max_length:]
        elif truncation == "right":
            input_ids = input_ids[..., :max_length]
            attention_mask = attention_mask[..., :max_length]
            position_ids = position_ids[..., :max_length]
        elif truncation == "error":
            raise NotImplementedError(f"{seq_length} is larger than {max_length}.")
        else:
            raise NotImplementedError(f"Unknown truncation method {truncation}.")

    return input_ids, attention_mask, position_ids
| |
|
| |
|
def get_constant_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    last_epoch: int = -1,
) -> torch.optim.lr_scheduler.LRScheduler:
    """Build a scheduler that warms the lr up linearly and then holds it constant.

    The lr multiplier grows as ``step / max(1, num_warmup_steps)`` and is capped at 1.0.
    """
    # Clamp to at least one step so a zero-length warmup does not divide by zero.
    warmup_span = float(max(1, num_warmup_steps))

    def lr_lambda(current_step: int) -> float:
        progress = float(current_step) / warmup_span
        return min(1.0, progress)

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
| |
|
| |
|
| | |
class AnyPrecisionAdamW(torch.optim.Optimizer):
    """AdamW-style optimizer whose state buffers use caller-chosen dtypes.

    The momentum, the uncentered variance and the optional Kahan summation
    compensation buffer are each allocated in their own configurable dtype.
    """

    def __init__(
        self,
        params: List[torch.Tensor],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        use_kahan_summation: bool = True,
        momentum_dtype: torch.dtype = torch.bfloat16,
        variance_dtype: torch.dtype = torch.bfloat16,
        compensation_buffer_dtype: torch.dtype = torch.bfloat16,
    ):
        """
        Args:
            params (iterable): iterable of parameters to optimize or dicts defining parameter groups
            lr (float, optional): learning rate (default: 1e-3)
            betas (Tuple[float, float], optional): coefficients used for computing
                running averages of gradient and its square (default: (0.9, 0.999))
            eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
            weight_decay (float, optional): weight decay coefficient (default: 0.0)

            # Any Precision specific
            use_kahan_summation = creates auxiliary buffer to ensure high precision
            model param updates (default: True)
            momentum_dtype = dtype for momentum (default: bfloat16)
            variance_dtype = dtype for uncentered variance (default: bfloat16)
            compensation_buffer_dtype = dtype for Kahan summation buffer (default: bfloat16)

            # Usage
            This optimizer implements optimizer states, and Kahan summation
            for high precision updates, all in user controlled dtypes.
            Defaults are momentum, variance and compensation buffer in BF16.
            This can be run in FSDP mixed precision, amp, or full precision,
            depending on what training pipeline you wish to work with.

            Setting to use_kahan_summation = False, and changing momentum and
            variance dtypes to FP32, reverts this to a standard AdamW optimizer.

        """
        defaults = {
            "lr": lr,
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
            "use_kahan_summation": use_kahan_summation,
            "momentum_dtype": momentum_dtype,
            "variance_dtype": variance_dtype,
            "compensation_buffer_dtype": compensation_buffer_dtype,
        }
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad, so gradients must be re-enabled for the closure.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            eps = group["eps"]
            use_kahan_summation = group["use_kahan_summation"]

            momentum_dtype = group["momentum_dtype"]
            variance_dtype = group["variance_dtype"]
            compensation_buffer_dtype = group["compensation_buffer_dtype"]
            for p in group["params"]:
                if p.grad is None:
                    continue

                if p.grad.is_sparse:
                    raise RuntimeError("AnyPrecisionAdamW does not support sparse gradients.")

                state = self.state[p]
                # Lazy state initialization on the first step for this parameter.
                if len(state) == 0:
                    state["step"] = torch.tensor(0.0)

                    # Exponential moving average of the gradient (momentum).
                    state["exp_avg"] = torch.zeros_like(p, dtype=momentum_dtype)

                    # Exponential moving average of the squared gradient (uncentered variance).
                    state["exp_avg_sq"] = torch.zeros_like(p, dtype=variance_dtype)

                    # Optional Kahan summation compensation buffer.
                    if use_kahan_summation:
                        state["compensation"] = torch.zeros_like(p, dtype=compensation_buffer_dtype)

                # Main processing: advance the step counter first so the bias
                # corrections below use the 1-based step count.
                state["step"] += 1
                step = state["step"]

                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]
                grad = p.grad

                # Weight decay is applied directly to the parameter, AdamW style.
                if weight_decay:
                    p.data.mul_(1 - lr * weight_decay)

                # Update the running averages of the gradient and its square.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Bias correction for the first moment is folded into the step size.
                bias_correction1 = 1 - beta1**step
                step_size = lr / bias_correction1

                # Bias-corrected root of the second moment, plus eps: the update denominator.
                denom_correction = (1 - beta2**step) ** 0.5
                centered_variance = (exp_avg_sq.sqrt() / denom_correction).add_(eps, alpha=1)

                if use_kahan_summation:
                    # Accumulate the update into the compensation buffer first.
                    compensation = state["compensation"]
                    compensation.addcdiv_(exp_avg, centered_variance, value=-step_size)

                    # Apply the buffer to the parameter, then keep in the buffer
                    # whatever part of the update the parameter did not absorb.
                    temp_buffer = p.detach().clone()
                    p.data.add_(compensation)
                    compensation.add_(temp_buffer.sub_(p.data))
                else:
                    # Plain AdamW update.
                    p.data.addcdiv_(exp_avg, centered_variance, value=-step_size)

        return loss
| |
|