| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import List, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers.cache_utils import DynamicCache |
|
|
| from specforge.core.eagle3_adapters import BackendAdapter, SdpaLikeAdapter, UspAdapter |
| from specforge.core.loss import LogSoftmaxLoss |
| from specforge.modeling.draft import Eagle3DraftModel |
| from specforge.utils import padding |
|
|
|
|
class Eagle3Model(nn.Module):
    """Marker base class shared by the Eagle3 training wrappers in this file.

    It adds no state or behavior on top of ``nn.Module``.
    """
|
|
|
|
class OnlineEagle3Model(Eagle3Model):
    """
    Online Eagle3 trainer that unrolls the draft model with test-time training (TTT).

    "Online" means the target model's hidden states and logits are supplied to
    ``forward`` at training time. Each ``forward`` call:

    1. Turns the target logits into draft-vocabulary probabilities and a
       position mask (``_compute_target_p_padded``).
    2. Projects the incoming target hidden states with
       ``draft_model.project_hidden_states``.
    3. Runs ``length`` TTT steps. Every step embeds the current input ids, runs
       the draft backbone on those embeddings plus the previous step's hidden
       states, computes logits, and records one accuracy and one loss.
       Between steps the input ids and masks go through
       ``padding(..., left=False)`` (presumably a one-token shift -- confirm
       against ``specforge.utils.padding``).
    """

    def __init__(
        self,
        draft_model: Eagle3DraftModel,
        length: int = 7,
        attention_backend="sdpa",
        target_model: Optional[Eagle3Model] = None,
    ):
        """
        Args:
            draft_model: the draft model to be trained.
            length: TTT length, i.e. how many steps to unroll during TTT.
            attention_backend: one of "sdpa", "fa", "usp" or "flex_attention".
                Validated when ``forward`` runs.
            target_model: optional target model. In this class it is only used
                to compute rope position ids for VLM inputs when no
                ``position_ids`` are passed to ``forward``.
        """
        super().__init__()
        self.draft_model = draft_model
        self.length = length
        self.attention_backend = attention_backend
        self.target_model = target_model

    def _make_adapter(self) -> BackendAdapter:
        """Return the backend adapter: ``UspAdapter`` for "usp", otherwise ``SdpaLikeAdapter``."""
        if self.attention_backend == "usp":
            return UspAdapter(self)
        return SdpaLikeAdapter(self)

    def _init_draft_cache(self):
        """Build the per-call draft cache objects for the configured backend.

        Returns:
            ``(cache_hidden, past_key_values)``: a fresh ``[[], []]`` list and
            ``None`` for "sdpa"/"fa"/"usp", or ``None`` and a fresh
            ``DynamicCache`` for "flex_attention".

        Raises:
            ValueError: if ``attention_backend`` is not one of the above.
        """
        if self.attention_backend in ["sdpa", "fa", "usp"]:
            return [[], []], None
        if self.attention_backend == "flex_attention":
            return None, DynamicCache()
        raise ValueError(f"Unknown attention backend: {self.attention_backend}")

    def _acc_and_loss(
        self,
        *,
        logits: torch.Tensor,
        target_p: torch.Tensor,
        position_mask: torch.Tensor,
        loss_mask: torch.Tensor,
        adapter: BackendAdapter,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute accuracy and loss for one TTT step.

        Accuracy is the number of positions where the draft argmax equals the
        target argmax (weighted by ``position_mask``) divided by the sum of
        ``loss_mask``; both terms go through ``adapter.reduce_metrics`` before
        the division. The loss comes from ``LogSoftmaxLoss`` and goes through
        ``adapter.reduce_loss``.

        Returns:
            ``(acc, loss)``.
        """
        # The metric is for logging only, so keep it out of the autograd graph.
        with torch.no_grad():
            matches = logits.argmax(-1) == target_p.argmax(-1)
            local_correct = (matches * position_mask.squeeze(-1)).sum()
            # Clamp so a fully masked batch does not divide by zero.
            local_denom = loss_mask.sum().clamp_min(1e-6)
            local_correct, local_denom = adapter.reduce_metrics(
                local_correct=local_correct, local_denom=local_denom
            )
            acc = local_correct / local_denom

        loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
        loss = adapter.reduce_loss(loss)
        return acc, loss

    def _prepare_position_ids(
        self,
        position_ids: Optional[torch.Tensor],
        *,
        seq_length: int,
        past_key_values_length: int,
        device: torch.device,
        is_vlm: bool,
        input_ids: torch.Tensor,
        image_grid_thw: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Return the position ids to feed to the draft backbone.

        * "usp" backend: ``position_ids`` is returned untouched.
        * ``position_ids`` given: cast to ``long`` and viewed as
          ``(-1, seq_length)``.
        * ``position_ids`` missing and ``is_vlm``: taken from
          ``target_model.get_rope_index``.
        * otherwise: a ``(1, seq_length)`` range starting at
          ``past_key_values_length``.

        Raises:
            ValueError: if VLM position ids must be derived but no
                ``target_model`` was given to the constructor.
        """
        if self.attention_backend == "usp":
            return position_ids

        if position_ids is not None:
            return position_ids.long().view(-1, seq_length)

        if is_vlm:
            if self.target_model is None:
                # Without this guard the failure is an opaque AttributeError on None.
                raise ValueError(
                    "target_model is required to compute VLM position ids "
                    "when position_ids is not provided"
                )
            mrope_position_ids, _ = self.target_model.get_rope_index(
                input_ids=input_ids, image_grid_thw=image_grid_thw
            )
            return mrope_position_ids

        return torch.arange(
            past_key_values_length,
            seq_length + past_key_values_length,
            dtype=torch.long,
            device=device,
        ).unsqueeze(0)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        target: torch.Tensor,
        loss_mask: torch.Tensor,
        hidden_states: torch.Tensor,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        position_ids: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.Tensor] = None,
        is_vlm: bool = False,
        **kwargs,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        """
        Online eagle model trainer, modified from: https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/cnets.py#L711

        Args:
            input_ids: (batch, seq_len)
            attention_mask: (batch, seq_len); an all-ones mask is built when None.
            target: target-model logits; released as soon as the target
                probabilities have been computed.
            loss_mask: (batch, seq_len)
            hidden_states: 3-D target hidden states whose first two dims are
                (batch, seq_len).
            past_key_values: not used as a cache by eagle3 (the draft cache is
                created internally); only its length is read, for compatibility.
            position_ids: (batch, seq_len)
            image_grid_thw: image grid thw, used to derive VLM position ids.
            is_vlm: whether missing position ids should come from the target
                model's rope index.
            **kwargs: accepted and ignored.

        Returns:
            ``(plosses, vlosses, acces)``: one loss and one accuracy per TTT
            step; ``vlosses`` is always empty and kept for interface
            compatibility.

        Raises:
            ValueError: if ``attention_backend`` is unknown.
        """
        # Fail fast on an unsupported backend before any tensor work is done.
        cache_hidden, draft_past_key_values = self._init_draft_cache()

        # Step 1: target distribution over the draft vocabulary, padded for TTT.
        target_p_padded, position_mask = _compute_target_p_padded(
            target=target,
            t2d=self.draft_model.t2d,
            loss_mask=loss_mask,
            length=self.length,
        )
        # The raw target logits are no longer needed; free them before unrolling.
        del target
        torch.cuda.empty_cache()

        # Step 2: shapes and projection of the target hidden states.
        batch_size, seq_length, _ = hidden_states.shape
        seq_length_with_past = seq_length
        past_key_values_length = 0

        hidden_states = self.draft_model.project_hidden_states(hidden_states)

        # Step 3: position ids.
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
        position_ids = self._prepare_position_ids(
            position_ids=position_ids,
            seq_length=seq_length,
            past_key_values_length=past_key_values_length,
            device=hidden_states.device,
            is_vlm=is_vlm,
            input_ids=input_ids,
            image_grid_thw=image_grid_thw,
        )

        # Step 4: attention mask.
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=hidden_states.device,
            )
        if self.attention_backend == "sdpa":
            attention_mask = self.draft_model.prepare_decoder_attention_mask(
                attention_mask=attention_mask,
                hidden_states=hidden_states,
                batch_size=batch_size,
                seq_length=seq_length,
                past_key_values_length=past_key_values_length,
            )

        # Step 5: TTT unroll.
        plosses = []
        vlosses = []
        acces = []
        adapter = self._make_adapter()
        global_input_ids = input_ids

        for idx in range(self.length):
            state = adapter.step_view(
                idx=idx,
                ttt_length=self.length,
                global_input_ids=global_input_ids,
                attention_mask=attention_mask,
                loss_mask=loss_mask,
                position_ids=position_ids,
                hidden_states=hidden_states,
                target_p_padded=target_p_padded,
                position_mask=position_mask,
                seq_length=seq_length,
            )
            is_last = idx == self.length - 1

            # Step 5.1: embed the input ids in the hidden-state dtype.
            inputs_embeds = self.draft_model.embed_input_ids(state.input_ids)
            inputs_embeds = inputs_embeds.to(hidden_states.dtype)

            # Step 5.2: run the draft backbone; its output feeds the next step.
            hidden_states = self.draft_model.backbone(
                input_embeds=inputs_embeds,
                hidden_states=state.hidden_states,
                cache_hidden=cache_hidden,
                attention_mask=state.attention_mask,
                position_ids=state.position_ids,
                past_key_values=draft_past_key_values,
                use_cache=True,
            )

            # Step 5.3: logits, accuracy and loss for this step.
            logits = self.draft_model.compute_logits(hidden_states)
            acc, loss = self._acc_and_loss(
                logits=logits,
                target_p=state.target_p,
                position_mask=state.position_mask,
                loss_mask=state.loss_mask,
                adapter=adapter,
            )
            acces.append(acc)
            plosses.append(loss)

            if not is_last:
                # Step 5.4: advance the ids and masks for the next TTT step.
                global_input_ids = padding(global_input_ids, left=False)
                position_mask = padding(position_mask, left=False)
                loss_mask = padding(loss_mask, left=False)

        return plosses, vlosses, acces
|
|
|
|
class QwenVLOnlineEagle3Model(Eagle3Model):
    """
    Online Eagle3 trainer for Qwen-VL style targets, using test-time training (TTT).

    "Online" means the target model is run inside ``forward`` to obtain the
    hidden states and logits used for training:

    1. Run the target model and take three auxiliary hidden states. With
       ``hidden_states[0]`` skipped via ``offset = 1``, the indices used are
       ``1``, ``num_layers // 2 - 1`` and ``num_layers - 4`` (plus the offset).
    2. Concatenate them on the last dim: (batch, seq_len, 3*hidden_size).
    3. Project them with ``draft_model.project_hidden_states``.
    4. Run ``length`` TTT steps of the draft backbone on the input embeddings
       and the running hidden states, recording one loss and accuracy per step.
    """

    def __init__(
        self,
        target_model,
        draft_model: Eagle3DraftModel,
        processor,
        length: int = 7,
        attention_backend: str = "sdpa",
    ):
        """
        Args:
            target_model: the target model to extract hidden states from.
            draft_model: the draft model to be trained.
            processor: stored on the instance; not used by the methods in this class.
            length: TTT length, i.e. how many steps to unroll during TTT.
            attention_backend: one of "sdpa", "fa" or "flex_attention".
                Validated when ``forward`` runs.
        """
        super().__init__()
        self.target_model = target_model
        self.draft_model = draft_model
        self.processor = processor
        self.length = length
        self.attention_backend = attention_backend

    def _init_draft_cache(self):
        """Build the per-call draft cache objects for the configured backend.

        Returns:
            ``(cache_hidden, past_key_values)``: a fresh ``[[], []]`` list and
            ``None`` for "sdpa"/"fa", or ``None`` and a fresh ``DynamicCache``
            for "flex_attention".

        Raises:
            ValueError: if ``attention_backend`` is not one of the above.
        """
        if self.attention_backend in ["sdpa", "fa"]:
            return [[], []], None
        if self.attention_backend == "flex_attention":
            return None, DynamicCache()
        raise ValueError(f"Unknown attention backend: {self.attention_backend}")

    @torch.no_grad()
    def _prepare_data(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.Tensor] = None,
        device: Optional[torch.device] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        modified from: https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/cnets.py#L692
        Run the target model and extract the training inputs from its outputs.

        Args:
            input_ids: (batch, seq_len)
            attention_mask: (batch, seq_len)
            loss_mask: (batch, seq_len)
            pixel_values: image pixel values, used for VLM models
            image_grid_thw: image grid thw, used for VLM models
            device: device for the returned target and loss mask; if None, use
                the input_ids device

        Returns:
            hidden_states: (batch, seq_len, 3*hidden_size)
            target: (batch, seq_len, vocab_size)
            loss_mask: (batch, seq_len, 1)
            input_ids: (batch, seq_len)
        """
        if device is None:
            device = input_ids.device

        # Step 1: run the target model, keeping every layer's hidden states.
        outputs = self.target_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            output_hidden_states=True,
            use_cache=False,
        )

        # Step 2: pick the three auxiliary layers. ``offset`` skips
        # hidden_states[0] (presumably the embedding output -- confirm against
        # the target model's output convention).
        num_hidden_states = len(outputs.hidden_states)
        offset = 1
        num_layers = num_hidden_states - 1

        low_aux_layer = 1 + offset
        mid_aux_layer = num_layers // 2 - 1 + offset
        last_aux_layer = num_layers - 4 + offset

        hidden_states = torch.cat(
            (
                outputs.hidden_states[low_aux_layer],
                outputs.hidden_states[mid_aux_layer],
                outputs.hidden_states[last_aux_layer],
            ),
            dim=-1,
        )

        # Step 3: pass the logits and input ids through padding(left=False).
        target = outputs.logits
        target = padding(target, left=False)
        input_ids = padding(input_ids, left=False)

        if target is not None:
            target = target.to(device)
            # Trailing singleton dim so the mask broadcasts against the vocab dim.
            loss_mask = loss_mask[..., None]
            loss_mask = loss_mask.to(device)

        return hidden_states, target, loss_mask, input_ids

    @torch.no_grad()
    def _get_input_embeds(
        self,
        input_ids: torch.Tensor,
        pixel_values: torch.Tensor,
        image_grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and overwrite the image-token positions with image features.

        Text embeddings come from the draft model, image features from
        ``target_model.model.get_image_features``.

        Raises:
            ValueError: if the number of image tokens in ``input_ids`` differs
                from the number of image feature rows.
        """
        inputs_embeds = self.draft_model.embed_input_ids(input_ids)
        image_embeds = self.target_model.model.get_image_features(
            pixel_values, image_grid_thw
        )
        image_embeds = torch.cat(image_embeds, dim=0)

        # Computed once and reused for both the count check and the scatter.
        image_token_mask = input_ids == self.target_model.model.config.image_token_id
        n_image_tokens = image_token_mask.sum()
        n_image_features = image_embeds.shape[0]
        if n_image_tokens != n_image_features:
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )

        image_mask = (
            image_token_mask.unsqueeze(-1)
            .expand_as(inputs_embeds)
            .to(inputs_embeds.device)
        )
        image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        return inputs_embeds.masked_scatter(image_mask, image_embeds)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        position_ids: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.Tensor] = None,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        """
        Online eagle model trainer, modified from: https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/cnets.py#L711

        Args:
            input_ids: (batch, seq_len)
            attention_mask: (batch, seq_len)
            loss_mask: (batch, seq_len)
            past_key_values: not used as a cache by eagle3 (the draft cache is
                created internally); only its length is read, for compatibility.
            position_ids: (batch, seq_len); derived from the target model's
                rope index when None.
            pixel_values: batch image pixel values, used for VLM models
            image_grid_thw: (batch, 3), image grid thw, used for VLM models

        Returns:
            ``(plosses, vlosses, acces)``: one loss and one accuracy per TTT
            step; ``vlosses`` is always empty and kept for interface
            compatibility.

        Raises:
            ValueError: if ``attention_backend`` is unknown.

        Side effects:
            Sets ``self.rope_deltas`` when position ids are derived here.
        """
        # Fail fast on an unsupported backend before the target model is run.
        cache_hidden, draft_past_key_values = self._init_draft_cache()

        # Step 0: run the target model to get hidden states and logits.
        hidden_states, target, loss_mask, input_ids = self._prepare_data(
            input_ids, attention_mask, loss_mask, pixel_values, image_grid_thw
        )

        # Step 1: target distribution over the draft vocabulary, padded for TTT.
        target_p_padded, position_mask = _compute_target_p_padded(
            target=target,
            t2d=self.draft_model.t2d,
            loss_mask=loss_mask,
            length=self.length,
        )
        del target

        # Step 2: shapes and projection of the target hidden states.
        batch_size, seq_length, _ = hidden_states.shape
        seq_length_with_past = seq_length
        past_key_values_length = 0

        hidden_states = self.draft_model.project_hidden_states(hidden_states)

        # Step 3: position ids.
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            attention_mask_tensor = (
                attention_mask
                if not isinstance(attention_mask, dict)
                else attention_mask["full_attention"]
            )
            if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
                # Reduce the 4D mask to one flag per position via its diagonal.
                attention_mask_tensor = torch.diagonal(
                    attention_mask_tensor[:, 0], dim1=1, dim2=2
                )
                # Only floating-point masks are rescaled by the dtype minimum;
                # torch.finfo raises for bool/int dtypes, so those are left as is.
                if attention_mask_tensor.is_floating_point():
                    attention_mask_tensor = (
                        attention_mask_tensor
                        / torch.finfo(attention_mask_tensor.dtype).min
                    )
                    attention_mask_tensor = (1.0 - attention_mask_tensor).int()

            position_ids, rope_deltas = self.target_model.model.get_rope_index(
                input_ids,
                image_grid_thw,
                None,
                second_per_grid_ts=None,
                attention_mask=attention_mask_tensor,
            )
            self.rope_deltas = rope_deltas

        # Step 4: attention mask.
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=hidden_states.device,
            )
        if self.attention_backend == "sdpa":
            attention_mask = self.draft_model.prepare_decoder_attention_mask(
                attention_mask=attention_mask,
                hidden_states=hidden_states,
                batch_size=batch_size,
                seq_length=seq_length,
                past_key_values_length=past_key_values_length,
            )

        # Step 5: TTT unroll.
        plosses = []
        vlosses = []
        acces = []

        for idx in range(self.length):
            # Window of the padded target distribution for this step.
            target_p = target_p_padded[:, idx : idx + seq_length, :].contiguous()
            is_last = idx == self.length - 1

            # Step 5.1: embed the input ids in the hidden-state dtype.
            inputs_embeds = self.draft_model.embed_input_ids(input_ids)
            inputs_embeds = inputs_embeds.to(hidden_states.dtype)

            # Step 5.2: run the draft backbone; its output feeds the next step.
            hidden_states = self.draft_model.backbone(
                input_embeds=inputs_embeds,
                hidden_states=hidden_states,
                cache_hidden=cache_hidden,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=draft_past_key_values,
                use_cache=True,
            )

            # Step 5.3: logits for this step.
            logits = self.draft_model.compute_logits(hidden_states)

            # Step 5.4: accuracy is for logging only, so keep it out of autograd.
            with torch.no_grad():
                acces.append(
                    _compute_metric_acc(
                        logits=logits,
                        target_p=target_p,
                        position_mask=position_mask,
                        loss_mask=loss_mask,
                    )
                )

            # Step 5.5: loss for this step.
            loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
            plosses.append(loss)

            if not is_last:
                # Step 5.6: advance the ids and masks for the next TTT step.
                input_ids = padding(input_ids, left=False)
                position_mask = padding(position_mask, left=False)
                loss_mask = padding(loss_mask, left=False)

        return plosses, vlosses, acces
|
|
|
|
def _compute_target_p_padded(target, t2d, loss_mask, length):
    """Compute the target distribution and pad it along the sequence dim for TTT.

    Args:
        target: target-model logits, forwarded to ``_compute_target_p``.
        t2d: target-to-draft vocabulary mask, forwarded to ``_compute_target_p``.
        loss_mask: loss mask, forwarded to ``_compute_target_p``.
        length: number of extra positions appended at the end of dim 1.

    Returns:
        ``(target_p_padded, position_mask)``. The appended positions are filled
        with ``1 / draft_vocab_size``, where ``draft_vocab_size`` is the last
        dim of the unpadded distribution.
    """
    with torch.no_grad():
        probs, position_mask = _compute_target_p(
            target=target, t2d=t2d, loss_mask=loss_mask
        )

        assert probs.dim() == 3
        fill_value = 1 / probs.shape[-1]
        # The pad spec is read from the last dim backwards: no padding on the
        # vocab dim, ``length`` extra positions at the end of the sequence dim.
        padded = F.pad(probs, pad=(0, 0, 0, length), mode="constant", value=fill_value)

    return padded, position_mask
|
|
|
|
| @torch.compile(dynamic=None) |
| def _compute_target_p(target, t2d, loss_mask): |
| target_head = target |
| target_max_token = target_head.argmax(-1) |
| target_mask = t2d[target_max_token] |
| target_mask = target_mask[..., None].int() |
| position_mask = target_mask * loss_mask |
| target_head = target_head[..., t2d] |
| target_head = target_head.float() |
| target_p = nn.Softmax(dim=2)(target_head) |
| target_p = target_p.detach() |
| return target_p, position_mask |
|
|
|
|
| @torch.compile(dynamic=None) |
| def _compute_metric_acc(logits, target_p, position_mask, loss_mask): |
| return ( |
| (logits.argmax(-1) == target_p.argmax(-1)) * position_mask.squeeze(-1) |
| ).sum() / loss_mask.sum().clamp_min(1e-6) |
|
|