import re
import warnings
from typing import Dict, Optional, Union

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    BaseModelOutputWithPoolingAndCrossAttentions,
)

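# Per-architecture names of the HF config attributes, plus the default pooler
# used by HFTextEncoder.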
_HF_ARCH_DICT = {
    'roberta': {
        'config_names': {
            'context_length': 'max_position_embeddings',
            'vocab_size': 'vocab_size',
            'width': 'hidden_size',
            'heads': 'num_attention_heads',
            'layers': 'num_hidden_layers',
            'layer_attr': 'layer',
            'token_embeddings_attr': 'embeddings',
        },
        'pooler': 'mean_pooler',
    },
    'xlm-roberta': {
        'config_names': {
            'context_length': 'max_position_embeddings',
            'vocab_size': 'vocab_size',
            'width': 'hidden_size',
            'heads': 'num_attention_heads',
            'layers': 'num_hidden_layers',
            'layer_attr': 'layer',
            'token_embeddings_attr': 'embeddings',
        },
        'pooler': 'mean_pooler',
    },
    'bert': {
        'config_names': {
            'context_length': 'max_position_embeddings',
            'vocab_size': 'vocab_size',
            'width': 'hidden_size',
            'heads': 'num_attention_heads',
            'layers': 'num_hidden_layers',
        },
        'pooler': 'cls_pooler',
    },
}
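
# Registry of pooler classes, keyed by snake-case class name via register_pooler.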
_POOLERS = {}


def _camel2snake(s):
    return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()


def register_pooler(cls):
    """Decorator registering pooler class"""
    _POOLERS[_camel2snake(cls.__name__)] = cls
    return cls


@register_pooler
class MeanPooler(nn.Module):
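    """Mean pooling over non-padding tokens."""
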
    @staticmethod
    def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
        masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
        return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)


@register_pooler
class MaxPooler(nn.Module):
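    """Max pooling over non-padding tokens."""
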
    @staticmethod
    def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
        # fill padding positions with -inf so they can never be selected by the max
        masked_output = x.last_hidden_state.masked_fill(
            attention_mask.unsqueeze(-1) == 0, -torch.inf
        )
        return masked_output.max(1).values


@register_pooler
class ClsPooler(nn.Module):
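    """CLS token pooling; prefers the model's own pooler output when available."""
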
    def __init__(self, use_pooler_output: bool = True):
        super().__init__()
        self.cls_token_position = 0
        self.use_pooler_output = use_pooler_output

    def forward(self, x: BaseModelOutput, _: torch.Tensor):
        if (
            self.use_pooler_output
            and isinstance(
                x,
                (
                    BaseModelOutputWithPooling,
                    BaseModelOutputWithPoolingAndCrossAttentions,
                ),
            )
            and (x.pooler_output is not None)
        ):
            return x.pooler_output
        return x.last_hidden_state[:, self.cls_token_position, :]


class HFTextEncoder(nn.Module):
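    """Text tower wrapping a HuggingFace transformer with configurable pooling and projection."""
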
    output_tokens: torch.jit.Final[bool]

    def __init__(
        self,
        model_name_or_path: str,
        output_dim: int,
        config: Optional[PretrainedConfig] = None,
        pooler_type: Optional[str] = None,
        proj_type: Optional[str] = None,
        proj_bias: bool = False,
        pretrained: bool = True,
        output_tokens: bool = False,
        trust_remote_code: bool = False,
        revision: Optional[str] = None,
        code_revision: Optional[str] = None,
        default_instruction_task: Optional[str] = None,
        default_lora_task: Optional[str] = None,
        model_config_kwargs: Optional[Dict] = None,
    ):
        super().__init__()
        self.output_tokens = output_tokens
        self.output_dim = output_dim

        model_config_kwargs = model_config_kwargs or {}

        if config is None:
            if pretrained:
                self.transformer = AutoModel.from_pretrained(
                    model_name_or_path,
                    trust_remote_code=trust_remote_code,
                    revision=revision,
                    add_pooling_layer=False,
                    code_revision=code_revision,
                    **model_config_kwargs,
                )
                self.config = self.transformer.config
            else:
                self.config = AutoConfig.from_pretrained(
                    model_name_or_path,
                    trust_remote_code=trust_remote_code,
                    code_revision=code_revision,
                )
                self.config.update(model_config_kwargs)
                self.transformer = AutoModel.from_config(
                    self.config,
                    trust_remote_code=trust_remote_code,
                    add_pooling_layer=False,
                    code_revision=code_revision,
                )
            if (
                hasattr(self.config, 'is_encoder_decoder')
                and self.config.is_encoder_decoder
            ):
                self.transformer = self.transformer.encoder
        else:
            self.config = config
            self.config.update(model_config_kwargs)
            self.transformer = AutoModel.from_config(
                self.config,
                trust_remote_code=trust_remote_code,
                revision=revision,
                code_revision=code_revision,
            )

        self.vocab_size = getattr(self.config, 'vocab_size', 0)
        self.context_length = getattr(self.config, 'max_position_embeddings', 0)

        pooler_type = pooler_type or _HF_ARCH_DICT[self.config.model_type]['pooler']
        self.pooler = _POOLERS[pooler_type]()

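        # optional projection head from the transformer hidden size to output_dim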
        d_model = getattr(
            self.config, _HF_ARCH_DICT[self.config.model_type]['config_names']['width']
        )
        if (d_model == output_dim) and (proj_type is None):
            self.proj = nn.Identity()
        elif (d_model != output_dim) or proj_type == 'linear':
            self.proj = nn.Linear(d_model, output_dim, bias=proj_bias)
        elif proj_type == 'mlp':
            hidden_size = (d_model + output_dim) // 2
            self.proj = nn.Sequential(
                nn.Linear(d_model, hidden_size, bias=proj_bias),
                nn.GELU(),
                nn.Linear(hidden_size, output_dim, bias=proj_bias),
            )

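        # detect optional task-instruction and LoRA-adapter support exposed by the loaded model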
        self._task_instructions = {}
        self._lora_adaptation_map = {}
        self._supports_task_instructions = False
        self._supports_lora = False
        if (
            hasattr(self.transformer, '_adaptation_map')
            and len(self.transformer._adaptation_map) > 0
        ):
            self._lora_adaptation_map = self.transformer._adaptation_map
            self._supports_lora = True
        if (
            hasattr(self.transformer, '_task_instructions')
            and len(self.transformer._task_instructions) > 0
        ):
            self._task_instructions = self.transformer._task_instructions
            self._supports_task_instructions = True

        self._default_instruction_task = None
        self._default_lora_task = None
        self._default_instruction = None
        self._default_loraid = None

        if default_instruction_task is not None:
            self._default_instruction_task = default_instruction_task
            self._default_instruction = self.get_instruction_from_task(
                default_instruction_task
            )
        if default_lora_task is not None:
            self._default_lora_task = default_lora_task
            self._default_loraid = self.get_loraid_from_task(default_lora_task)

    @property
    def supports_task_instructions(self) -> bool:
        return self._supports_task_instructions

    @property
    def supports_lora(self) -> bool:
        return self._supports_lora

    @property
    def task_instructions(self) -> Dict[str, str]:
        return self._task_instructions

    @property
    def lora_adaptation_map(self) -> Dict[str, int]:
        return self._lora_adaptation_map

    @property
    def default_instruction(self) -> Optional[str]:
        return self._default_instruction

    @property
    def default_loraid(self) -> Optional[int]:
        return self._default_loraid

    def get_instruction_from_task(self, task: Optional[str]) -> Optional[str]:
        if self._supports_task_instructions:
            if task is None:
                return self._default_instruction
            if task not in self._task_instructions:
                raise ValueError(
                    f'Unsupported task \'{task}\'. Choose one of the following: '
                    f'{", ".join(self._task_instructions)} or set to None to disable '
                    f'task instructions completely'
                )
            return self._task_instructions[task]
        else:
            if task is not None:
                warnings.warn(
                    'Model does not support task instructions, ignoring instruction '
                    f"task '{task}'"
                )
            return None

    def get_loraid_from_task(self, task: Optional[str]) -> Optional[int]:
        if self._supports_lora:
            if task is None:
                return self._default_loraid
            if task not in self._lora_adaptation_map:
                raise ValueError(
                    f'Unsupported task \'{task}\'. Choose one of the following: '
                    f'{", ".join(self._lora_adaptation_map)} or set to None to disable '
                    f'the LoRA adapters completely'
                )
            return self._lora_adaptation_map[task]
        else:
            if task is not None:
                warnings.warn(
                    f"Model does not support LoRA adapters, ignoring LoRA task '{task}'"
                )
            return None

    @staticmethod
    def get_adapter_mask_from_loraid(
        batch_size: int, loraid: int, device: Union[str, torch.device]
    ):
        return torch.full((batch_size,), loraid, dtype=torch.int32, device=device)

    @torch.jit.ignore
    def set_grad_checkpointing(self, _=True):
        self.transformer.gradient_checkpointing_enable()

    def init_parameters(self):
        pass

    def forward(self, x: torch.Tensor, adapter_mask: Optional[torch.Tensor] = None):
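        """Embed token ids ``x``; returns projected pooled features (plus token features if ``output_tokens``)."""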
        if adapter_mask is None:
            default_loraid = self.default_loraid
            if default_loraid is not None:
                adapter_mask = self.get_adapter_mask_from_loraid(
                    x.shape[0], default_loraid, x.device
                )
        else:
            if not self.supports_lora:
                warnings.warn(
                    'Model does not support LoRA adapters, setting adapter_mask to None'
                )
                adapter_mask = None

        attention_mask = (x != self.config.pad_token_id).long()
        lora_kwargs = {}
        if adapter_mask is not None:
            lora_kwargs['adapter_mask'] = adapter_mask

        out = self.transformer(
            input_ids=x, attention_mask=attention_mask, **lora_kwargs
        )
        pooled_out = self.pooler(out, attention_mask)
        projected = self.proj(pooled_out)
        seqlen = out.last_hidden_state.shape[1]
        tokens = (
            out.last_hidden_state[
                :, torch.arange(seqlen) != self.pooler.cls_token_position, :
            ]
            if isinstance(self.pooler, ClsPooler)
            else out.last_hidden_state
        )
        if self.output_tokens:
            return projected, tokens
        return projected

    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
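        """Freeze the transformer, optionally keeping the last ``unlocked_layers`` layers and/or LayerNorm parameters trainable."""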
        if not unlocked_layers:
            for n, p in self.transformer.named_parameters():
                p.requires_grad = (
                    (not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
                )
            return

        encoder = (
            self.transformer.encoder
            if hasattr(self.transformer, 'encoder')
            else self.transformer
        )
        layer_list = getattr(
            encoder, _HF_ARCH_DICT[self.config.model_type]['config_names']['layer_attr']
        )
        print(f'Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model')
        embeddings = getattr(
            self.transformer,
            _HF_ARCH_DICT[self.config.model_type]['config_names'][
                'token_embeddings_attr'
            ],
        )
        # freeze the embeddings and every layer except the last ``unlocked_layers`` blocks
        modules = [embeddings, *layer_list][:-unlocked_layers]

        for module in modules:
            for n, p in module.named_parameters():
                p.requires_grad = (
                    (not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
                )